Skip to content

Commit

Permalink
[CG-10837] feat: Linear tools error handling, extra test, request ret…
Browse files Browse the repository at this point in the history
…ry (#653)

# Motivation

<!-- Why is this change necessary? -->

# Content

<!-- Please include a summary of the change -->

# Testing

<!-- How was the change tested? -->

# Please check the following before marking your PR as ready for review

- [ ] I have added tests for my changes
- [ ] I have updated the documentation or added new documentation as
needed

---------

Co-authored-by: tomcodgen <[email protected]>
  • Loading branch information
tomcodgen and tomcodgen authored Feb 25, 2025
1 parent f26ed85 commit ac86411
Show file tree
Hide file tree
Showing 5 changed files with 229 additions and 14 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ dependencies = [
"lox>=0.12.0",
"httpx>=0.28.1",
"docker>=6.1.3",
"urllib3>=2.0.0",
]

license = { text = "Apache-2.0" }
Expand Down
43 changes: 29 additions & 14 deletions src/codegen/extensions/linear/linear_client.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import json
import logging
import os
from typing import Optional

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

from codegen.extensions.linear.types import LinearComment, LinearIssue, LinearTeam, LinearUser

Expand All @@ -14,7 +15,7 @@ class LinearClient:
api_headers: dict
api_endpoint = "https://api.linear.app/graphql"

def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] = None):
def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] = None, max_retries: int = 3, backoff_factor: float = 0.5):
if not access_token:
access_token = os.getenv("LINEAR_ACCESS_TOKEN")
if not access_token:
Expand All @@ -31,6 +32,18 @@ def __init__(self, access_token: Optional[str] = None, team_id: Optional[str] =
"Authorization": self.access_token,
}

# Set up a session with retry logic
self.session = requests.Session()
retry_strategy = Retry(
total=max_retries,
backoff_factor=backoff_factor,
status_forcelist=[429, 500, 502, 503, 504],
allowed_methods=["POST", "GET"], # POST is important for GraphQL
)
adapter = HTTPAdapter(max_retries=retry_strategy)
self.session.mount("https://", adapter)
self.session.mount("http://", adapter)

def get_issue(self, issue_id: str) -> LinearIssue:
query = """
query getIssue($issueId: String!) {
Expand All @@ -42,7 +55,7 @@ def get_issue(self, issue_id: str) -> LinearIssue:
}
"""
variables = {"issueId": issue_id}
response = requests.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables})
response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables})
data = response.json()
issue_data = data["data"]["issue"]
return LinearIssue(id=issue_data["id"], title=issue_data["title"], description=issue_data["description"])
Expand All @@ -66,7 +79,7 @@ def get_issue_comments(self, issue_id: str) -> list[LinearComment]:
}
"""
variables = {"issueId": issue_id}
response = requests.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables})
response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": query, "variables": variables})
data = response.json()
comments = data["data"]["issue"]["comments"]["nodes"]

Expand All @@ -80,8 +93,8 @@ def get_issue_comments(self, issue_id: str) -> list[LinearComment]:
# Convert raw comments to LinearComment objects
return parsed_comments

def comment_on_issue(self, issue_id: str, body: str) -> dict:
"""issue_id is our internal issue ID"""
def comment_on_issue(self, issue_id: str, body: str) -> LinearComment:
"""Add a comment to an issue."""
query = """mutation makeComment($issueId: String!, $body: String!) {
commentCreate(input: {issueId: $issueId, body: $body}) {
comment {
Expand All @@ -97,19 +110,21 @@ def comment_on_issue(self, issue_id: str, body: str) -> dict:
}
"""
variables = {"issueId": issue_id, "body": body}
response = requests.post(
response = self.session.post(
self.api_endpoint,
headers=self.api_headers,
data=json.dumps({"query": query, "variables": variables}),
json={"query": query, "variables": variables},
)
data = response.json()
try:
comment_data = data["data"]["commentCreate"]["comment"]
user_data = comment_data.get("user", None)
user = LinearUser(id=user_data["id"], name=user_data["name"]) if user_data else None

return comment_data
return LinearComment(id=comment_data["id"], body=comment_data["body"], user=user)
except Exception as e:
msg = f"Error creating comment\n{data}\n{e}"
raise Exception(msg)
raise ValueError(msg)

def register_webhook(self, webhook_url: str, team_id: str, secret: str, enabled: bool, resource_types: list[str]):
mutation = """
Expand All @@ -134,7 +149,7 @@ def register_webhook(self, webhook_url: str, team_id: str, secret: str, enabled:
}
}

response = requests.post(self.api_endpoint, headers=self.api_headers, json={"query": mutation, "variables": variables})
response = self.session.post(self.api_endpoint, headers=self.api_headers, json={"query": mutation, "variables": variables})
body = response.json()
return body

Expand All @@ -160,7 +175,7 @@ def search_issues(self, query: str, limit: int = 10) -> list[LinearIssue]:
}
"""
variables = {"query": query, "limit": limit}
response = requests.post(
response = self.session.post(
self.api_endpoint,
headers=self.api_headers,
json={"query": graphql_query, "variables": variables},
Expand Down Expand Up @@ -222,7 +237,7 @@ def create_issue(self, title: str, description: str | None = None, team_id: str
}
}

response = requests.post(
response = self.session.post(
self.api_endpoint,
headers=self.api_headers,
json={"query": mutation, "variables": variables},
Expand Down Expand Up @@ -258,7 +273,7 @@ def get_teams(self) -> list[LinearTeam]:
}
"""

response = requests.post(
response = self.session.post(
self.api_endpoint,
headers=self.api_headers,
json={"query": query},
Expand Down
179 changes: 179 additions & 0 deletions src/codegen/extensions/tools/linear/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import ClassVar

import requests
from pydantic import Field

from codegen.extensions.linear.linear_client import LinearClient
Expand Down Expand Up @@ -97,7 +98,32 @@ def linear_get_issue_tool(client: LinearClient, issue_id: str) -> LinearIssueObs
issue_id=issue_id,
issue_data=issue.dict(),
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearIssueObservation(
status="error",
error=f"Network error when fetching issue: {e!s}",
issue_id=issue_id,
issue_data={},
)
except ValueError as e:
# Input validation errors
return LinearIssueObservation(
status="error",
error=f"Invalid input: {e!s}",
issue_id=issue_id,
issue_data={},
)
except KeyError as e:
# Missing data in response
return LinearIssueObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
issue_id=issue_id,
issue_data={},
)
except Exception as e:
# Catch-all for other errors
return LinearIssueObservation(
status="error",
error=f"Failed to get issue: {e!s}",
Expand All @@ -115,7 +141,32 @@ def linear_get_issue_comments_tool(client: LinearClient, issue_id: str) -> Linea
issue_id=issue_id,
comments=[comment.dict() for comment in comments],
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearCommentsObservation(
status="error",
error=f"Network error when fetching comments: {e!s}",
issue_id=issue_id,
comments=[],
)
except ValueError as e:
# Input validation errors
return LinearCommentsObservation(
status="error",
error=f"Invalid input: {e!s}",
issue_id=issue_id,
comments=[],
)
except KeyError as e:
# Missing data in response
return LinearCommentsObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
issue_id=issue_id,
comments=[],
)
except Exception as e:
# Catch-all for other errors
return LinearCommentsObservation(
status="error",
error=f"Failed to get issue comments: {e!s}",
Expand All @@ -133,7 +184,32 @@ def linear_comment_on_issue_tool(client: LinearClient, issue_id: str, body: str)
issue_id=issue_id,
comment=comment,
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearCommentObservation(
status="error",
error=f"Network error when adding comment: {e!s}",
issue_id=issue_id,
comment={},
)
except ValueError as e:
# Input validation errors
return LinearCommentObservation(
status="error",
error=f"Invalid input: {e!s}",
issue_id=issue_id,
comment={},
)
except KeyError as e:
# Missing data in response
return LinearCommentObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
issue_id=issue_id,
comment={},
)
except Exception as e:
# Catch-all for other errors
return LinearCommentObservation(
status="error",
error=f"Failed to comment on issue: {e!s}",
Expand All @@ -159,7 +235,35 @@ def linear_register_webhook_tool(
team_id=team_id,
response=response,
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearWebhookObservation(
status="error",
error=f"Network error when registering webhook: {e!s}",
webhook_url=webhook_url,
team_id=team_id,
response={},
)
except ValueError as e:
# Input validation errors
return LinearWebhookObservation(
status="error",
error=f"Invalid input: {e!s}",
webhook_url=webhook_url,
team_id=team_id,
response={},
)
except KeyError as e:
# Missing data in response
return LinearWebhookObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
webhook_url=webhook_url,
team_id=team_id,
response={},
)
except Exception as e:
# Catch-all for other errors
return LinearWebhookObservation(
status="error",
error=f"Failed to register webhook: {e!s}",
Expand All @@ -178,7 +282,32 @@ def linear_search_issues_tool(client: LinearClient, query: str, limit: int = 10)
query=query,
issues=[issue.dict() for issue in issues],
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearSearchObservation(
status="error",
error=f"Network error when searching issues: {e!s}",
query=query,
issues=[],
)
except ValueError as e:
# Input validation errors
return LinearSearchObservation(
status="error",
error=f"Invalid input: {e!s}",
query=query,
issues=[],
)
except KeyError as e:
# Missing data in response
return LinearSearchObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
query=query,
issues=[],
)
except Exception as e:
# Catch-all for other errors
return LinearSearchObservation(
status="error",
error=f"Failed to search issues: {e!s}",
Expand All @@ -197,7 +326,35 @@ def linear_create_issue_tool(client: LinearClient, title: str, description: str
team_id=team_id,
issue_data=issue.dict(),
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearCreateIssueObservation(
status="error",
error=f"Network error when creating issue: {e!s}",
title=title,
team_id=team_id,
issue_data={},
)
except ValueError as e:
# Input validation errors
return LinearCreateIssueObservation(
status="error",
error=f"Invalid input: {e!s}",
title=title,
team_id=team_id,
issue_data={},
)
except KeyError as e:
# Missing data in response
return LinearCreateIssueObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
title=title,
team_id=team_id,
issue_data={},
)
except Exception as e:
# Catch-all for other errors
return LinearCreateIssueObservation(
status="error",
error=f"Failed to create issue: {e!s}",
Expand All @@ -215,7 +372,29 @@ def linear_get_teams_tool(client: LinearClient) -> LinearTeamsObservation:
status="success",
teams=[team.dict() for team in teams],
)
except requests.exceptions.RequestException as e:
# Network-related errors
return LinearTeamsObservation(
status="error",
error=f"Network error when fetching teams: {e!s}",
teams=[],
)
except ValueError as e:
# Input validation errors
return LinearTeamsObservation(
status="error",
error=f"Invalid input: {e!s}",
teams=[],
)
except KeyError as e:
# Missing data in response
return LinearTeamsObservation(
status="error",
error=f"Unexpected API response format: {e!s}",
teams=[],
)
except Exception as e:
# Catch-all for other errors
return LinearTeamsObservation(
status="error",
error=f"Failed to get teams: {e!s}",
Expand Down
Loading

0 comments on commit ac86411

Please sign in to comment.