Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
148 changes: 123 additions & 25 deletions pyjavapoet/java_file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""
Copyright (C) 2015 Square, Inc.

Expand Down Expand Up @@ -43,12 +43,14 @@
file_comment: Optional[CodeBlock],
indent: str,
static_imports: dict[ClassName, set[str]],
additional_imports: set[str],
):
self.package_name = package_name
self.type_spec = type_spec
self.file_comment = file_comment
self.indent = indent
self.static_imports = static_imports
self.additional_imports = additional_imports

def write_to_dir(self, java_dir: Path) -> Path:
"""
Expand Down Expand Up @@ -92,6 +94,106 @@
self.emit(writer)
# Write to the output
out.write(str(writer))

def _extract_wildcard_imports(self) -> set[str]:
"""Extract wildcard package imports from additional imports.

Returns:
Set of package names that have wildcard imports (e.g., 'java.util' for 'java.util.*')
"""
wildcard_packages = set()

for import_name in self.additional_imports:
if import_name.endswith(".*"):
# Wildcard import - extract package name
package = import_name[:-2] # Remove ".*"
wildcard_packages.add(package)

return wildcard_packages

def _merge_all_specific_imports(
self,
collected_imports: dict[str, set[str]],
wildcard_packages: set[str]
) -> dict[str, set[str]]:
"""Merge all specific imports (collected + additional) while excluding wildcard-covered packages.

Args:
collected_imports: Imports collected from type usage in the code
wildcard_packages: Packages with wildcard imports to exclude

Returns:
Final resolved specific imports (excluding those covered by wildcards)
"""
final_imports = {}

# Add collected imports, but skip those covered by wildcards
for package, simple_names in collected_imports.items():
if package not in wildcard_packages:
if package not in final_imports:
final_imports[package] = set()
final_imports[package].update(simple_names)

# Process additional specific imports and add them
for import_name in self.additional_imports:
if not import_name.endswith(".*") and "." in import_name:
# This is a specific import, not a wildcard
package, simple_name = import_name.rsplit(".", 1)

# Only add if not covered by wildcard
if package not in wildcard_packages:
if package not in final_imports:
final_imports[package] = set()
final_imports[package].add(simple_name)

return final_imports

def _emit_all_imports(
self,
code_writer: CodeWriter,
final_imports: dict[str, set[str]],
wildcard_packages: set[str]
) -> None:
"""Emit all import statements in the correct order.

Args:
code_writer: The CodeWriter to emit to
final_imports: The resolved specific imports to emit
wildcard_packages: The wildcard packages to emit
"""
# Emit static imports first
static_imports = sorted([
f"import static {type_name.canonical_name}.{member};"
for type_name, members in self.static_imports.items()
for member in sorted(members)
])

if static_imports:
for static_import in static_imports:
code_writer.emit(static_import)
code_writer.emit("\n")
code_writer.emit("\n")

# Combine wildcard and specific imports, then sort them together
all_imports = []

# Add wildcard imports
for package in wildcard_packages:
all_imports.append(f"import {package}.*;")

# Add specific imports
for package in final_imports:
for simple_name in final_imports[package]:
all_imports.append(f"import {package}.{simple_name};")

# Sort all imports alphabetically and emit them
for import_statement in sorted(all_imports):
code_writer.emit(import_statement)
code_writer.emit("\n")

# Add blank line after imports if there were any
if all_imports:
code_writer.emit("\n")

def emit(self, code_writer: CodeWriter) -> None:
# Emit file comment
Expand All @@ -110,31 +212,11 @@
)
self.type_spec.emit(import_collector)

# Get the imports
imports = import_collector.get_imports()

# Emit static imports
static_imports = sorted(
[
f"import static {type_name.canonical_name}.{member};"
for type_name, members in self.static_imports.items()
for member in sorted(members)
]
)
if static_imports:
for static_import in static_imports:
code_writer.emit(static_import)
code_writer.emit("\n")
code_writer.emit("\n")

# Emit normal imports
import_packages = sorted(imports.keys())
for package in import_packages:
for simple_name in sorted(imports[package]):
code_writer.emit(f"import {package}.{simple_name};\n")

if import_packages:
code_writer.emit("\n")
# Get the imports from type usage and process all imports
collected_imports = import_collector.get_imports()
wildcard_packages = self._extract_wildcard_imports()
final_imports = self._merge_all_specific_imports(collected_imports, wildcard_packages)
self._emit_all_imports(code_writer, final_imports, wildcard_packages)

# Emit the type
self.type_spec.emit(code_writer)
Expand All @@ -151,6 +233,7 @@
self.file_comment,
self.indent,
self.static_imports,
self.additional_imports,
)

def __str__(self) -> str:
Expand All @@ -174,12 +257,14 @@
file_comment: Optional[CodeBlock] = None,
indent: str = " ",
static_imports: dict[ClassName, set[str]] | None = None,
additional_imports: set[str] | None = None,
):
self.__package_name = package_name
self.__type_spec = type_spec
self.__file_comment = file_comment
self.__indent = indent
self.__static_imports = static_imports or {}
self.__additional_imports = additional_imports or set()

def add_file_comment(self, format_string: str = EMPTY_STRING, *args) -> "JavaFile.Builder":
self.__file_comment = CodeBlock.add_javadoc(self.__file_comment, format_string, *args)
Expand Down Expand Up @@ -207,11 +292,24 @@

return self

def add_additional_import(self, import_name: str) -> "JavaFile.Builder":
"""Add an additional import that will be included in the generated file.

Args:
import_name: The import to add (e.g., "java.util.List" or "java.util.*")

Returns:
This builder for method chaining
"""
self.__additional_imports.add(import_name)
return self

def build(self) -> "JavaFile":
return JavaFile(
self.__package_name,
self.__type_spec,
self.__file_comment,
self.__indent,
self.__static_imports,
self.__additional_imports,
)
61 changes: 61 additions & 0 deletions tests/test_java_file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""
Copyright (C) 2025 Matthew Au-Yeung.

Expand Down Expand Up @@ -492,6 +492,67 @@
self.assertIn("record Point", result)
self.assertIn("implements Serializable", result)

def test_add_additional_import_arbitrary(self):
"""Test adding an arbitrary import unrelated to existing imports."""
# Create a simple class with no external dependencies
type_spec = (
TypeSpec.class_builder("SimpleClass")
.add_modifiers(Modifier.PUBLIC)
.add_method(
MethodSpec.method_builder("doSomething")
.add_modifiers(Modifier.PUBLIC)
.returns("void")
.add_statement("System.out.println($S)", "Hello World")
.build()
)
.build()
)

# Build JavaFile and add an arbitrary import
java_file = (
JavaFile.builder("com.example", type_spec)
.add_additional_import("java.util.Collections")
.build()
)

result = str(java_file)

# Should contain the additional import even though it's not used
self.assertIn("import java.util.Collections;", result)
# Should still contain the class
self.assertIn("public class SimpleClass", result)

def test_add_additional_import_wildcard_covers_existing(self):
"""Test wildcard import that covers an existing specific import."""
# Create a class that uses java.util.List
list_type = ClassName.get("java.util", "List").with_type_arguments(ClassName.get("java.lang", "String"))
type_spec = (
TypeSpec.class_builder("ListClass")
.add_modifiers(Modifier.PUBLIC)
.add_field(
FieldSpec.builder(list_type, "items")
.add_modifiers(Modifier.PRIVATE)
.build()
)
.build()
)

# Build JavaFile and add wildcard import that covers the specific one
java_file = (
JavaFile.builder("com.example", type_spec)
.add_additional_import("java.util.*")
.build()
)

result = str(java_file)

# Should contain the wildcard import
self.assertIn("import java.util.*;", result)
# Should NOT contain the specific List import since it's covered by wildcard
self.assertNotIn("import java.util.List;", result)
# Should still contain the class and use the simple name
self.assertIn("List<String> items", result)


class JavaFileReadWriteTest(unittest.TestCase):
"""Test file reading functionality."""
Expand Down
Loading