Skip to content

Commit

Permalink
Revert "Fix module resolution bug" (#227)
Browse files Browse the repository at this point in the history
Reverts #190
  • Loading branch information
bagel897 authored Jan 30, 2025
1 parent efaeb77 commit b0e3670
Show file tree
Hide file tree
Showing 6 changed files with 14 additions and 110 deletions.
8 changes: 3 additions & 5 deletions src/codegen/sdk/core/external_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ class ExternalModule(
"""

node_type: Literal[NodeType.EXTERNAL] = NodeType.EXTERNAL
_import: Import | None = None

def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name, import_node: Import | None = None) -> None:
def __init__(self, ts_node: TSNode, file_node_id: NodeId, G: CodebaseGraph, import_name: Name) -> None:
self.node_id = G.add_node(self)
super().__init__(ts_node, file_node_id, G, None)
self._name_node = import_name
self.return_type = StubPlaceholder(parent=self)
assert self._idx_key not in self.G._ext_module_idx
self.G._ext_module_idx[self._idx_key] = self.node_id
self._import = import_node

@property
def _idx_key(self) -> str:
Expand All @@ -70,7 +68,7 @@ def from_import(cls, imp: Import) -> ExternalModule:
Returns:
ExternalModule: A new ExternalModule instance representing the external module.
"""
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node, imp)
return cls(imp.ts_node, imp.file_node_id, imp.G, imp._unique_node)

@property
@reader
Expand Down Expand Up @@ -138,7 +136,7 @@ def viz(self) -> VizNode:
@noapidoc
@reader
def resolve_attribute(self, name: str) -> ExternalModule | None:
return self._import.resolve_attribute(name) or self
return self

@noapidoc
@commiter
Expand Down
16 changes: 2 additions & 14 deletions src/codegen/sdk/core/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from codegen.sdk.core.expressions.name import Name
from codegen.sdk.core.external_module import ExternalModule
from codegen.sdk.core.interfaces.chainable import Chainable
from codegen.sdk.core.interfaces.has_attribute import HasAttribute
from codegen.sdk.core.interfaces.usable import Usable
from codegen.sdk.core.statements.import_statement import ImportStatement
from codegen.sdk.enums import EdgeType, ImportType, NodeType
Expand Down Expand Up @@ -58,7 +57,7 @@ class ImportResolution(Generic[TSourceFile]):


@apidoc
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile], HasAttribute[TSourceFile]):
class Import(Usable[ImportStatement], Chainable, Generic[TSourceFile]):
"""Represents a single symbol being imported.
For example, this is one `Import` in Python (and similar applies to Typescript, etc.):
Expand Down Expand Up @@ -116,7 +115,7 @@ def __rich_repr__(self) -> rich.repr.Result:

@noapidoc
@abstractmethod
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSourceFile] | None:
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSourceFile] | None:
"""Resolves the import to a symbol defined outside the file.
Returns an ImportResolution object.
Expand Down Expand Up @@ -663,17 +662,6 @@ def remove_if_unused(self) -> None:
):
self.remove()

@noapidoc
@reader
def resolve_attribute(self, attribute: str) -> TSourceFile | None:
# Handles implicit namespace imports in python
if not isinstance(self._imported_symbol(), ExternalModule):
return None
resolved = self.resolve_import(add_module_name=attribute)
if resolved:
return resolved.symbol or resolved.from_file
return None


TImport = TypeVar("TImport", bound="Import")

Expand Down
20 changes: 1 addition & 19 deletions src/codegen/sdk/python/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from codegen.sdk.core.file import SourceFile
from codegen.sdk.core.interface import Interface
from codegen.sdk.enums import ImportType, ProgrammingLanguage
from codegen.sdk.extensions.utils import cached_property, iter_all_descendants
from codegen.sdk.extensions.utils import iter_all_descendants
from codegen.sdk.python import PyAssignment
from codegen.sdk.python.class_definition import PyClass
from codegen.sdk.python.detached_symbols.code_block import PyCodeBlock
Expand All @@ -20,7 +20,6 @@

if TYPE_CHECKING:
from codegen.sdk.codebase.codebase_graph import CodebaseGraph
from codegen.sdk.core.import_resolution import WildcardImport
from codegen.sdk.python.symbol import PySymbol


Expand Down Expand Up @@ -174,20 +173,3 @@ def add_import_from_import_string(self, import_string: str) -> None:
def remove_unused_exports(self) -> None:
"""Removes unused exports from the file. NO-OP for python"""
pass

@cached_property
@noapidoc
@reader(cache=True)
def valid_import_names(self) -> dict[str, PySymbol | PyImport | WildcardImport[PyImport]]:
"""Returns a dict mapping name => Symbol (or import) in this file that can be imported from
another file.
"""
if self.name == "__init__":
ret = {}
if self.directory:
for file in self.directory:
if file.name == "__init__":
continue
ret[file.name] = file
return ret
return super().valid_import_names
17 changes: 7 additions & 10 deletions src/codegen/sdk/python/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,13 +82,10 @@ def imported_exports(self) -> list[Exportable]:

@noapidoc
@reader
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[PyFile] | None:
def resolve_import(self, base_path: str | None = None) -> ImportResolution[PyFile] | None:
base_path = base_path or self.G.projects[0].base_path or ""
module_source = self.module.source if self.module else ""
symbol_name = self.symbol_name.source if self.symbol_name else ""
if add_module_name:
module_source += f".{symbol_name}"
symbol_name = add_module_name

# If import is relative, convert to absolute path
if module_source.startswith("."):
module_source = self._relative_to_absolute_import(module_source)
Expand All @@ -102,7 +99,7 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
# `from a.b.c import foo`
filepath = os.path.join(
base_path,
module_source.replace(".", "/") + "/" + symbol_name + ".py",
module_source.replace(".", "/") + "/" + self.symbol_name.source + ".py",
)
if file := self.G.get_file(filepath):
return ImportResolution(from_file=file, symbol=None, imports_file=True)
Expand All @@ -117,22 +114,22 @@ def resolve_import(self, base_path: str | None = None, *, add_module_name: str |
filepath = module_source.replace(".", "/") + ".py"
filepath = os.path.join(base_path, filepath)
if file := self.G.get_file(filepath):
symbol = file.get_node_by_name(symbol_name)
symbol = file.get_node_by_name(self.symbol_name.source)
return ImportResolution(from_file=file, symbol=symbol)

# =====[ Check if `module/__init__.py` file exists in the graph ]=====
filepath = filepath.replace(".py", "/__init__.py")
if from_file := self.G.get_file(filepath):
symbol = from_file.get_node_by_name(symbol_name)
symbol = from_file.get_node_by_name(self.symbol_name.source)
return ImportResolution(from_file=from_file, symbol=symbol)

# =====[ Case: Can't resolve the import ]=====
if base_path == "":
# Try to resolve with "src" as the base path
return self.resolve_import(base_path="src", add_module_name=add_module_name)
return self.resolve_import(base_path="src")
if base_path == "src":
# Try "test" next
return self.resolve_import(base_path="test", add_module_name=add_module_name)
return self.resolve_import(base_path="test")

# if not G_override:
# for resolver in G.import_resolvers:
Expand Down
2 changes: 1 addition & 1 deletion src/codegen/sdk/typescript/import_resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ def resolved_symbol(self) -> Symbol | ExternalModule | TSFile | None:
return resolved_symbol

@reader
def resolve_import(self, base_path: str | None = None, *, add_module_name: str | None = None) -> ImportResolution[TSFile] | None:
def resolve_import(self, base_path: str | None = None) -> ImportResolution[TSFile] | None:
"""Resolves an import statement to its target file and symbol.
This method is used by GraphBuilder to resolve import statements to their target files and symbols. It handles both relative and absolute imports,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,64 +249,3 @@ def c_sym():
assert "c_sym" in b_file.valid_symbol_names
assert "a_sym" in c_file.valid_symbol_names
assert "b_sym" in c_file.valid_symbol_names.keys()


def test_import_resolution_nested_module(tmpdir: str) -> None:
"""Tests import resolution works with nested module imports"""
# language=python
with get_codebase_session(
tmpdir,
files={
"a/b/c.py": """
def d():
pass
""",
"consumer.py": """
from a import b
b.c.d()
""",
},
) as codebase:
consumer_file: SourceFile = codebase.get_file("consumer.py")
c_file: SourceFile = codebase.get_file("a/b/c.py")

# Verify import resolution
assert len(consumer_file.imports) == 1

# Verify function call resolution
d_func = c_file.get_function("d")
call_sites = d_func.call_sites
assert len(call_sites) == 1
assert call_sites[0].file == consumer_file


def test_import_resolution_nested_module_init(tmpdir: str) -> None:
"""Tests import resolution works with nested module imports"""
# language=python
with get_codebase_session(
tmpdir,
files={
"a/b/c.py": """
def d():
pass
""",
"a/b/__init__.py": """""",
"consumer.py": """
from a import b
b.c.d()
""",
},
) as codebase:
consumer_file: SourceFile = codebase.get_file("consumer.py")
c_file: SourceFile = codebase.get_file("a/b/c.py")

# Verify import resolution
assert len(consumer_file.imports) == 1

# Verify function call resolution
d_func = c_file.get_function("d")
call_sites = d_func.call_sites
assert len(call_sites) == 1
assert call_sites[0].file == consumer_file

0 comments on commit b0e3670

Please sign in to comment.