From 6b1d9962a2a4f815b9f9d6432fa3cefcef7c1b7d Mon Sep 17 00:00:00 2001 From: Matze Date: Fri, 21 Feb 2025 13:31:22 +0100 Subject: [PATCH 1/4] Check PYTHONPATH value when resolving a package --- CONTRIBUTING.md | 4 +++ docs/building-with-codegen/imports.mdx | 5 ++++ .../cli/mcp/resources/system_prompt.py | 5 ++++ src/codegen/sdk/python/import_resolution.py | 29 +++++++++++++++++++ src/codegen/sdk/system-prompt.txt | 5 ++++ .../test_import_resolution.py | 22 ++++++++++++-- 6 files changed, 68 insertions(+), 2 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a169cf8ad..3f882ddf1 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -51,6 +51,10 @@ uv run pytest tests/unit -n auto uv run pytest tests/integration/codemod/test_codemods.py -n auto ``` +> [!TIP] +> +> - If on Linux the error `OSError: [Errno 24] Too many open files` appears then increase it with `ulimit -n 10000` + ## Pull Request Process 1. Fork the repository and create your branch from `develop`. diff --git a/docs/building-with-codegen/imports.mdx b/docs/building-with-codegen/imports.mdx index c66e7736d..95ecff990 100644 --- a/docs/building-with-codegen/imports.mdx +++ b/docs/building-with-codegen/imports.mdx @@ -69,6 +69,11 @@ print(f"From file: {import_stmt.from_file.filepath}") print(f"To file: {import_stmt.to_file.filepath}") ``` + +With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving +packages. + + ## Working with External Modules You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so: diff --git a/src/codegen/cli/mcp/resources/system_prompt.py b/src/codegen/cli/mcp/resources/system_prompt.py index a44bb2f38..9535570ab 100644 --- a/src/codegen/cli/mcp/resources/system_prompt.py +++ b/src/codegen/cli/mcp/resources/system_prompt.py @@ -2858,6 +2858,11 @@ def validate_data(data: dict) -> bool: print(f"To file: {import_stmt.to_file.filepath}") ``` + +With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving +packages. + + ## Working with External Modules You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so: diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index a80bb2ada..a3fcb1825 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -107,6 +107,10 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if file := self.ctx.get_file(filepath): return ImportResolution(from_file=file, symbol=None, imports_file=True) + # =====[ Check if we are importing an entire file with PYTHONPATH set ]===== + if file := self._file_by_pythonpath(filepath): + return ImportResolution(from_file=file, symbol=None, imports_file=True) + filepath = filepath.replace(".py", "/__init__.py") if file := self.ctx.get_file(filepath): # TODO - I think this is another edge case, due to `dao/__init__.py` etc. @@ -120,6 +124,11 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | symbol = file.get_node_by_name(symbol_name) return ImportResolution(from_file=file, symbol=symbol) + # =====[ Check if `module.py` file exists in the graph with PYTHONPATH set ]===== + if file := self._file_by_pythonpath(filepath): + symbol = file.get_node_by_name(symbol_name) + return ImportResolution(from_file=file, symbol=symbol) + # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): @@ -148,6 +157,26 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | # ext = ExternalModule.from_import(self) # return ImportResolution(symbol=ext) + @noapidoc + @reader + def _file_by_pythonpath(self, filepath: str) -> SourceFile | None: + """Helper to check if a certain import is in the list of files when a PYTHONPATH + is set. Returns either None or the SourceFile. + """ + python_paths: str | None = os.environ.get("PYTHONPATH", None) + if python_paths is None: + return None + + python_path_tokens: list[str] = python_paths.split(":") + for python_path in python_path_tokens: + if len(python_path) == 0: + continue + filepath_new: str = os.path.join(python_path, filepath) + if file := self.ctx.get_file(filepath_new): + return file + + return None + @noapidoc @reader def _relative_to_absolute_import(self, relative_import: str) -> str: diff --git a/src/codegen/sdk/system-prompt.txt b/src/codegen/sdk/system-prompt.txt index 646aa0bfa..ac2e45227 100644 --- a/src/codegen/sdk/system-prompt.txt +++ b/src/codegen/sdk/system-prompt.txt @@ -2879,6 +2879,11 @@ print(f"From file: {import_stmt.from_file.filepath}") print(f"To file: {import_stmt.to_file.filepath}") ``` + +With Python one can specify the `PYTHONPATH` environment variable which is then considered when resolving +packages. + + ## Working with External Modules You can determine if an import references an [ExternalModule](/api-reference/core/ExternalModule) by checking the type of [Import.imported_symbol](/api-reference/core/Import#imported-symbol), like so: diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 6fd9cbe7b..b96d9ae46 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -178,7 +178,7 @@ def test_import_properties_multi_imports(tmpdir) -> None: assert z[7].alias.source == "A" -def test_import_resolution_file_import(tmpdir: str) -> None: +def test_import_resolution_file_import(tmpdir: str, monkeypatch) -> None: """Tests function.usages returns usages from file imports""" # language=python with get_codebase_session( @@ -190,9 +190,14 @@ def update(): """, "consumer.py": """ from a.b.c import src as operations +from b.c import src as operations_pythonpath +from b.c.src import funct_1 def func_1(): operations.update() + +def func_2(): + operations_pythonpath.update() """, }, ) as codebase: @@ -200,13 +205,26 @@ def func_1(): consumer_file: SourceFile = codebase.get_file("consumer.py") # =====[ Imports & Resolution ]===== - assert len(consumer_file.imports) == 1 + assert len(consumer_file.imports) == 3 src_import: Import = consumer_file.imports[0] src_import_resolution: ImportResolution = src_import.resolve_import() assert src_import_resolution assert src_import_resolution.from_file is src_file assert src_import_resolution.imports_file is True + # =====[ Consider PYTHONPATH env variable ]===== + monkeypatch.setenv("PYTHONPATH", "a") + src_import = consumer_file.imports[1] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True + src_import = consumer_file.imports[2] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is False + # =====[ Look at call-sites ]===== update = src_file.get_function("update") call_sites = update.call_sites From 1532c40b41a3e976d690fc64ded31dc0a8260281 Mon Sep 17 00:00:00 2001 From: Matze Date: Fri, 21 Feb 2025 20:38:27 +0100 Subject: [PATCH 2/4] Update CONTRIBUTING.md with PR feedback --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 3f882ddf1..7417db81e 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -53,7 +53,7 @@ uv run pytest tests/integration/codemod/test_codemods.py -n auto > [!TIP] > -> - If on Linux the error `OSError: [Errno 24] Too many open files` appears then increase it with `ulimit -n 10000` +> - If on Linux the error `OSError: [Errno 24] Too many open files` appears then you might want to increase your _ulimit_ ## Pull Request Process From 7a7df7ae69f87b4836858ff504c2d362349044bb Mon Sep 17 00:00:00 2001 From: Matthias Bartelt Date: Mon, 24 Feb 2025 13:38:13 +0100 Subject: [PATCH 3/4] Add flag to enable sys.path package resolving plus custom overrites --- src/codegen/configs/models/codebase.py | 2 + src/codegen/sdk/python/import_resolution.py | 39 ++-- .../test_import_resolution.py | 199 ++++++++++++++++-- 3 files changed, 200 insertions(+), 40 deletions(-) diff --git a/src/codegen/configs/models/codebase.py b/src/codegen/configs/models/codebase.py index c5c96b1ce..22f1ff93a 100644 --- a/src/codegen/configs/models/codebase.py +++ b/src/codegen/configs/models/codebase.py @@ -17,6 +17,8 @@ def __init__(self, prefix: str = "CODEBASE", *args, **kwargs) -> None: disable_graph: bool = False generics: bool = True import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {}) + py_resolve_overrides: list[str] = Field(default_factory=lambda: []) + py_resolve_syspath: bool = False ts_dependency_manager: bool = False ts_language_engine: bool = False v8_ts_engine: bool = False diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index a3fcb1825..6bb900aa0 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -1,6 +1,7 @@ from __future__ import annotations import os +import sys from typing import TYPE_CHECKING from codegen.sdk.core.autocommit import reader @@ -107,9 +108,12 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | if file := self.ctx.get_file(filepath): return ImportResolution(from_file=file, symbol=None, imports_file=True) - # =====[ Check if we are importing an entire file with PYTHONPATH set ]===== - if file := self._file_by_pythonpath(filepath): - return ImportResolution(from_file=file, symbol=None, imports_file=True) + # =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]===== + if len(self.ctx.config.py_resolve_overrides) > 0 or self.ctx.config.py_resolve_syspath: + # Handle resolve overrides first if both is set + resolve_paths: list[str] = self.ctx.config.py_resolve_overrides + (sys.path if self.ctx.config.py_resolve_syspath else []) + if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): + return ImportResolution(from_file=file, symbol=None, imports_file=True) filepath = filepath.replace(".py", "/__init__.py") if file := self.ctx.get_file(filepath): @@ -124,10 +128,13 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | symbol = file.get_node_by_name(symbol_name) return ImportResolution(from_file=file, symbol=symbol) - # =====[ Check if `module.py` file exists in the graph with PYTHONPATH set ]===== - if file := self._file_by_pythonpath(filepath): - symbol = file.get_node_by_name(symbol_name) - return ImportResolution(from_file=file, symbol=symbol) + # =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]===== + if len(self.ctx.config.py_resolve_overrides) > 0 or self.ctx.config.py_resolve_syspath: + # Handle resolve overrides first if both is set + resolve_paths: list[str] = self.ctx.config.py_resolve_overrides + (sys.path if self.ctx.config.py_resolve_syspath else []) + if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): + symbol = file.get_node_by_name(symbol_name) + return ImportResolution(from_file=file, symbol=symbol) # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") @@ -159,19 +166,13 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | @noapidoc @reader - def _file_by_pythonpath(self, filepath: str) -> SourceFile | None: - """Helper to check if a certain import is in the list of files when a PYTHONPATH - is set. Returns either None or the SourceFile. + def _file_by_custom_resolve_paths(self, resolve_paths: list[str], filepath: str) -> SourceFile | None: + """Check if a certain file import can be found within a set sys.path + + Returns either None or the SourceFile. """ - python_paths: str | None = os.environ.get("PYTHONPATH", None) - if python_paths is None: - return None - - python_path_tokens: list[str] = python_paths.split(":") - for python_path in python_path_tokens: - if len(python_path) == 0: - continue - filepath_new: str = os.path.join(python_path, filepath) + for resolve_path in resolve_paths: + filepath_new: str = os.path.join(resolve_path, filepath) if file := self.ctx.get_file(filepath_new): return file diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index b96d9ae46..3ffc078a6 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -1,3 +1,4 @@ +import sys from typing import TYPE_CHECKING from codegen.sdk.codebase.factory.get_session import get_codebase_session @@ -178,7 +179,7 @@ def test_import_properties_multi_imports(tmpdir) -> None: assert z[7].alias.source == "A" -def test_import_resolution_file_import(tmpdir: str, monkeypatch) -> None: +def test_import_resolution_file_import(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python with get_codebase_session( @@ -190,14 +191,9 @@ def update(): """, "consumer.py": """ from a.b.c import src as operations -from b.c import src as operations_pythonpath -from b.c.src import funct_1 -def func_1(): +def func(): operations.update() - -def func_2(): - operations_pythonpath.update() """, }, ) as codebase: @@ -205,26 +201,13 @@ def func_2(): consumer_file: SourceFile = codebase.get_file("consumer.py") # =====[ Imports & Resolution ]===== - assert len(consumer_file.imports) == 3 + assert len(consumer_file.imports) == 1 src_import: Import = consumer_file.imports[0] src_import_resolution: ImportResolution = src_import.resolve_import() assert src_import_resolution assert src_import_resolution.from_file is src_file assert src_import_resolution.imports_file is True - # =====[ Consider PYTHONPATH env variable ]===== - monkeypatch.setenv("PYTHONPATH", "a") - src_import = consumer_file.imports[1] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is True - src_import = consumer_file.imports[2] - src_import_resolution = src_import.resolve_import() - assert src_import_resolution - assert src_import_resolution.from_file is src_file - assert src_import_resolution.imports_file is False - # =====[ Look at call-sites ]===== update = src_file.get_function("update") call_sites = update.call_sites @@ -233,6 +216,180 @@ def func_2(): assert call_site.file == consumer_file +def test_import_resolution_file_syspath_inactive(tmpdir: str, monkeypatch) -> None: + """Tests function.usages returns usages from file imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/b/c/src.py": """ +def update(): + pass +""", + "consumer.py": """ +from b.c import src as operations + +def func(): + operations.update() +""", + }, + ) as codebase: + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Disable resolution via sys.path + codebase.ctx.config.py_resolve_syspath = False + + # =====[ Imports cannot be found without sys.path being set and not active ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution: ImportResolution = src_import.resolve_import() + assert src_import_resolution is None + + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") + + # =====[ Imports cannot be found with sys.path set but not active ]===== + src_import_resolution = src_import.resolve_import() + assert src_import_resolution is None + + +def test_import_resolution_file_syspath_active(tmpdir: str, monkeypatch) -> None: + """Tests function.usages returns usages from file imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/b/c/src.py": """ +def update(): + pass +""", + "consumer.py": """ +from b.c import src as operations + +def func(): + operations.update() +""", + }, + ) as codebase: + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Enable resolution via sys.path + codebase.ctx.config.py_resolve_syspath = True + + # =====[ Imports cannot be found without sys.path being set ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution: ImportResolution = src_import.resolve_import() + assert src_import_resolution is None + + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") + + # =====[ Imports can be found with sys.path set and active ]===== + codebase.ctx.config.py_resolve_syspath = True + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True + + +def test_import_resolution_file_custom_resolve_path(tmpdir: str) -> None: + """Tests function.usages returns usages from file imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/b/c/src.py": """ +def update(): + pass +""", + "consumer.py": """ +from b.c import src as operations +from c import src as operations2 + +def func(): + operations.update() +""", + }, + ) as codebase: + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # =====[ Imports cannot be found without custom resolve path being set ]===== + assert len(consumer_file.imports) == 2 + src_import: Import = consumer_file.imports[0] + src_import_resolution: ImportResolution = src_import.resolve_import() + assert src_import_resolution is None + + # =====[ Imports cannot be found with custom resolve path set to invalid path ]===== + codebase.ctx.config.py_resolve_overrides = ["x"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution is None + + # =====[ Imports can be found with custom resolve path set ]===== + codebase.ctx.config.py_resolve_overrides = ["a"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True + + # =====[ Imports can be found with custom resolve multi-path set ]===== + src_import = consumer_file.imports[1] + codebase.ctx.config.py_resolve_overrides = ["a/b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True + + +def test_import_resolution_file_custom_resolve_and_syspath(tmpdir: str, monkeypatch) -> None: + """Tests function.usages returns usages from file imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/c/src.py": """ +def update1(): + pass +""", + "a/b/c/src.py": """ +def update2(): + pass +""", + "consumer.py": """ +from c import src as operations + +def func(): + operations.update2() +""", + }, + ) as codebase: + src_file: SourceFile = codebase.get_file("a/b/c/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.py_resolve_overrides = [] + codebase.ctx.config.py_resolve_syspath = True + + # =====[ Import with sys.path set can be found ]===== + assert len(consumer_file.imports) == 1 + # Modify sys.path for this test only + monkeypatch.syspath_prepend("a") + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file.file_path == "a/c/src.py" + + # =====[ Imports can be found with custom resolve over sys.path ]===== + codebase.ctx.config.py_resolve_overrides = ["a/b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + assert src_import_resolution.imports_file is True + + def test_import_resolution_circular(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python From 862e8018fd140b1551a444d2d3e9fe29a05d023c Mon Sep 17 00:00:00 2001 From: Matthias Bartelt Date: Tue, 25 Feb 2025 18:15:28 +0100 Subject: [PATCH 4/4] Changed order to have specific import resolutions first. Renaming config variable --- src/codegen/configs/models/codebase.py | 2 +- src/codegen/sdk/python/import_resolution.py | 28 ++++---- .../test_import_resolution.py | 65 +++++++++++++++++-- 3 files changed, 75 insertions(+), 20 deletions(-) diff --git a/src/codegen/configs/models/codebase.py b/src/codegen/configs/models/codebase.py index 22f1ff93a..0e04cead9 100644 --- a/src/codegen/configs/models/codebase.py +++ b/src/codegen/configs/models/codebase.py @@ -16,8 +16,8 @@ def __init__(self, prefix: str = "CODEBASE", *args, **kwargs) -> None: ignore_process_errors: bool = True disable_graph: bool = False generics: bool = True + import_resolution_paths: list[str] = Field(default_factory=lambda: []) import_resolution_overrides: dict[str, str] = Field(default_factory=lambda: {}) - py_resolve_overrides: list[str] = Field(default_factory=lambda: []) py_resolve_syspath: bool = False ts_dependency_manager: bool = False ts_language_engine: bool = False diff --git a/src/codegen/sdk/python/import_resolution.py b/src/codegen/sdk/python/import_resolution.py index 6bb900aa0..7dbce0c7c 100644 --- a/src/codegen/sdk/python/import_resolution.py +++ b/src/codegen/sdk/python/import_resolution.py @@ -105,37 +105,39 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str | base_path, module_source.replace(".", "/") + "/" + symbol_name + ".py", ) - if file := self.ctx.get_file(filepath): - return ImportResolution(from_file=file, symbol=None, imports_file=True) # =====[ Check if we are importing an entire file with custom resolve path or sys.path enabled ]===== - if len(self.ctx.config.py_resolve_overrides) > 0 or self.ctx.config.py_resolve_syspath: + if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: # Handle resolve overrides first if both is set - resolve_paths: list[str] = self.ctx.config.py_resolve_overrides + (sys.path if self.ctx.config.py_resolve_syspath else []) + resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): return ImportResolution(from_file=file, symbol=None, imports_file=True) + # =====[ Default path ]===== + if file := self.ctx.get_file(filepath): + return ImportResolution(from_file=file, symbol=None, imports_file=True) + filepath = filepath.replace(".py", "/__init__.py") if file := self.ctx.get_file(filepath): # TODO - I think this is another edge case, due to `dao/__init__.py` etc. # You can't do `from a.b.c import foo` => `foo.utils.x` right now since `foo` is just a file... return ImportResolution(from_file=file, symbol=None, imports_file=True) - # =====[ Check if `module.py` file exists in the graph ]===== - filepath = module_source.replace(".", "/") + ".py" - filepath = os.path.join(base_path, filepath) - if file := self.ctx.get_file(filepath): - symbol = file.get_node_by_name(symbol_name) - return ImportResolution(from_file=file, symbol=symbol) - # =====[ Check if `module.py` file exists in the graph with custom resolve path or sys.path enabled ]===== - if len(self.ctx.config.py_resolve_overrides) > 0 or self.ctx.config.py_resolve_syspath: + filepath = module_source.replace(".", "/") + ".py" + if len(self.ctx.config.import_resolution_paths) > 0 or self.ctx.config.py_resolve_syspath: # Handle resolve overrides first if both is set - resolve_paths: list[str] = self.ctx.config.py_resolve_overrides + (sys.path if self.ctx.config.py_resolve_syspath else []) + resolve_paths: list[str] = self.ctx.config.import_resolution_paths + (sys.path if self.ctx.config.py_resolve_syspath else []) if file := self._file_by_custom_resolve_paths(resolve_paths, filepath): symbol = file.get_node_by_name(symbol_name) return ImportResolution(from_file=file, symbol=symbol) + # =====[ Check if `module.py` file exists in the graph ]===== + filepath = os.path.join(base_path, filepath) + if file := self.ctx.get_file(filepath): + symbol = file.get_node_by_name(symbol_name) + return ImportResolution(from_file=file, symbol=symbol) + # =====[ Check if `module/__init__.py` file exists in the graph ]===== filepath = filepath.replace(".py", "/__init__.py") if from_file := self.ctx.get_file(filepath): diff --git a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py index 3ffc078a6..735a50d26 100644 --- a/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py +++ b/tests/unit/codegen/sdk/python/import_resolution/test_import_resolution.py @@ -324,12 +324,12 @@ def func(): assert src_import_resolution is None # =====[ Imports cannot be found with custom resolve path set to invalid path ]===== - codebase.ctx.config.py_resolve_overrides = ["x"] + codebase.ctx.config.import_resolution_paths = ["x"] src_import_resolution = src_import.resolve_import() assert src_import_resolution is None # =====[ Imports can be found with custom resolve path set ]===== - codebase.ctx.config.py_resolve_overrides = ["a"] + codebase.ctx.config.import_resolution_paths = ["a"] src_import_resolution = src_import.resolve_import() assert src_import_resolution assert src_import_resolution.from_file is src_file @@ -337,14 +337,14 @@ def func(): # =====[ Imports can be found with custom resolve multi-path set ]===== src_import = consumer_file.imports[1] - codebase.ctx.config.py_resolve_overrides = ["a/b"] + codebase.ctx.config.import_resolution_paths = ["a/b"] src_import_resolution = src_import.resolve_import() assert src_import_resolution assert src_import_resolution.from_file is src_file assert src_import_resolution.imports_file is True -def test_import_resolution_file_custom_resolve_and_syspath(tmpdir: str, monkeypatch) -> None: +def test_import_resolution_file_custom_resolve_and_syspath_precedence(tmpdir: str, monkeypatch) -> None: """Tests function.usages returns usages from file imports""" # language=python with get_codebase_session( @@ -370,7 +370,7 @@ def func(): consumer_file: SourceFile = codebase.get_file("consumer.py") # Ensure we don't have overrites and enable syspath resolution - codebase.ctx.config.py_resolve_overrides = [] + codebase.ctx.config.import_resolution_paths = [] codebase.ctx.config.py_resolve_syspath = True # =====[ Import with sys.path set can be found ]===== @@ -383,13 +383,66 @@ def func(): assert src_import_resolution.from_file.file_path == "a/c/src.py" # =====[ Imports can be found with custom resolve over sys.path ]===== - codebase.ctx.config.py_resolve_overrides = ["a/b"] + codebase.ctx.config.import_resolution_paths = ["a/b"] src_import_resolution = src_import.resolve_import() assert src_import_resolution assert src_import_resolution.from_file is src_file assert src_import_resolution.imports_file is True +def test_import_resolution_default_conflicts_overrite(tmpdir: str, monkeypatch) -> None: + """Tests function.usages returns usages from file imports""" + # language=python + with get_codebase_session( + tmpdir, + files={ + "a/src.py": """ +def update1(): + pass +""", + "b/a/src.py": """ +def update2(): + pass +""", + "consumer.py": """ +from a import src as operations + +def func(): + operations.update2() +""", + }, + ) as codebase: + src_file: SourceFile = codebase.get_file("a/src.py") + src_file_overrite: SourceFile = codebase.get_file("b/a/src.py") + consumer_file: SourceFile = codebase.get_file("consumer.py") + + # Ensure we don't have overrites and enable syspath resolution + codebase.ctx.config.import_resolution_paths = [] + codebase.ctx.config.py_resolve_syspath = True + + # =====[ Default import works ]===== + assert len(consumer_file.imports) == 1 + src_import: Import = consumer_file.imports[0] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is src_file + + # =====[ Sys.path overrite has precedence ]===== + monkeypatch.syspath_prepend("b") + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite + + # =====[ Custom overrite has precedence ]===== + codebase.ctx.config.import_resolution_paths = ["b"] + src_import_resolution = src_import.resolve_import() + assert src_import_resolution + assert src_import_resolution.from_file is not src_file + assert src_import_resolution.from_file is src_file_overrite + + + def test_import_resolution_circular(tmpdir: str) -> None: """Tests function.usages returns usages from file imports""" # language=python