Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow custom overrides for package resolving and optional sys.path support #601

Merged
merged 9 commits into from
Feb 26, 2025
4 changes: 4 additions & 0 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ uv run pytest tests/unit -n auto
uv run pytest tests/integration/codemod/test_codemods.py -n auto
```

> [!TIP]
>
> - If the error `OSError: [Errno 24] Too many open files` appears on Linux, raise the open-file limit with `ulimit -n 10000`.

## Pull Request Process

1. Fork the repository and create your branch from `develop`.
Expand Down
5 changes: 5 additions & 0 deletions docs/building-with-codegen/imports.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ print(f"From file: {import_stmt.from_file.filepath}")
print(f"To file: {import_stmt.to_file.filepath}")
```

<Note>
For Python codebases, you can set the `PYTHONPATH` environment variable; the directories it lists are then
also considered when resolving imports.
</Note>

## Working with External Modules

You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:
Expand Down
5 changes: 5 additions & 0 deletions src/codegen/cli/mcp/resources/system_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -2858,6 +2858,11 @@ def validate_data(data: dict) -> bool:
print(f"To file: {import_stmt.to_file.filepath}")
```
<Note>
For Python codebases, you can set the `PYTHONPATH` environment variable; the directories it lists are then
also considered when resolving imports.
</Note>
## Working with External Modules
You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:
Expand Down
29 changes: 29 additions & 0 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
if file := self.ctx.get_file(filepath):
return ImportResolution(from_file=file, symbol=None, imports_file=True)

# =====[ Check if we are importing an entire file with PYTHONPATH set ]=====
if file := self._file_by_pythonpath(filepath):
return ImportResolution(from_file=file, symbol=None, imports_file=True)

filepath = filepath.replace(".py", "/__init__.py")
if file := self.ctx.get_file(filepath):
# TODO - I think this is another edge case, due to `dao/__init__.py` etc.
Expand All @@ -120,6 +124,11 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
symbol = file.get_node_by_name(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module.py` file exists in the graph with PYTHONPATH set ]=====
if file := self._file_by_pythonpath(filepath):
symbol = file.get_node_by_name(symbol_name)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.ctx.get_file(filepath):
Expand Down Expand Up @@ -148,6 +157,26 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
# ext = ExternalModule.from_import(self)
# return ImportResolution(symbol=ext)

@noapidoc
@reader
def _file_by_pythonpath(self, filepath: str) -> SourceFile | None:
    """Look up ``filepath`` relative to each entry of the ``PYTHONPATH`` environment variable.

    Args:
        filepath: Candidate path of the imported module's file, as built by the caller.

    Returns:
        The first file returned by ``self.ctx.get_file`` when ``filepath`` is joined
        onto a ``PYTHONPATH`` entry, or ``None`` when the variable is unset or empty,
        or when no entry yields a known file.
    """
    python_paths = os.environ.get("PYTHONPATH")
    if not python_paths:
        return None

    # PYTHONPATH entries are separated by the platform's path-list separator
    # (":" on POSIX, ";" on Windows). Splitting on a literal ":" would also cut
    # Windows drive letters such as "C:\\src" in half.
    for python_path in python_paths.split(os.pathsep):
        # Leading, trailing or doubled separators produce empty entries; skip them.
        if not python_path:
            continue
        if file := self.ctx.get_file(os.path.join(python_path, filepath)):
            return file

    return None

@noapidoc
@reader
def _relative_to_absolute_import(self, relative_import: str) -> str:
Expand Down
5 changes: 5 additions & 0 deletions src/codegen/sdk/system-prompt.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2879,6 +2879,11 @@ print(f"From file: {import_stmt.from_file.filepath}")
print(f"To file: {import_stmt.to_file.filepath}")
```

<Note>
For Python codebases, you can set the `PYTHONPATH` environment variable; the directories it lists are then
also considered when resolving imports.
</Note>

## Working with External Modules

You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def test_import_properties_multi_imports(tmpdir) -> None:
assert z[7].alias.source == "A"


def test_import_resolution_file_import(tmpdir: str) -> None:
def test_import_resolution_file_import(tmpdir: str, monkeypatch) -> None:
"""Tests function.usages returns usages from file imports"""
# language=python
with get_codebase_session(
Expand All @@ -190,23 +190,41 @@ def update():
""",
"consumer.py": """
from a.b.c import src as operations
from b.c import src as operations_pythonpath
from b.c.src import funct_1

def func_1():
operations.update()

def func_2():
operations_pythonpath.update()
""",
},
) as codebase:
src_file: SourceFile = codebase.get_file("a/b/c/src.py")
consumer_file: SourceFile = codebase.get_file("consumer.py")

# =====[ Imports & Resolution ]=====
assert len(consumer_file.imports) == 1
assert len(consumer_file.imports) == 3
src_import: Import = consumer_file.imports[0]
src_import_resolution: ImportResolution = src_import.resolve_import()
assert src_import_resolution
assert src_import_resolution.from_file is src_file
assert src_import_resolution.imports_file is True

# =====[ Consider PYTHONPATH env variable ]=====
monkeypatch.setenv("PYTHONPATH", "a")
src_import = consumer_file.imports[1]
src_import_resolution = src_import.resolve_import()
assert src_import_resolution
assert src_import_resolution.from_file is src_file
assert src_import_resolution.imports_file is True
src_import = consumer_file.imports[2]
src_import_resolution = src_import.resolve_import()
assert src_import_resolution
assert src_import_resolution.from_file is src_file
assert src_import_resolution.imports_file is False

# =====[ Look at call-sites ]=====
update = src_file.get_function("update")
call_sites = update.call_sites
Expand Down