#! /usr/bin/env python3 """ Code generation script for class methods to be exported as public API """ from __future__ import annotations import argparse import ast import os import subprocess import sys from pathlib import Path from textwrap import indent from typing import TYPE_CHECKING import attrs if TYPE_CHECKING: from collections.abc import Iterable, Iterator from typing_extensions import TypeGuard # keep these imports up to date with conditional imports in test_gen_exports # isort: split import astor PREFIX = "_generated" HEADER = """# *********************************************************** # ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ****** # ************************************************************* from __future__ import annotations import sys from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED from ._run import GLOBAL_RUN_CONTEXT """ TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True try: return{}GLOBAL_RUN_CONTEXT.{}.{} except AttributeError: raise RuntimeError("must be called from async context") from None """ @attrs.define class File: path: Path modname: str platform: str = attrs.field(default="", kw_only=True) imports: str = attrs.field(default="", kw_only=True) def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node is either a function or an async function """ return isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]: """Check if the AST node has a _public decorator""" if is_function(node): for decorator in node.decorator_list: if isinstance(decorator, ast.Name) and decorator.id == "_public": return True return False def get_public_methods( tree: ast.AST, ) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]: """Return a list of methods marked as public. The function walks the given tree and extracts all objects that are functions which are marked public. """ for node in ast.walk(tree): if is_public(node): yield node def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str: """Given a function definition, create a string that represents taking all the arguments from the function, and passing them through to another invocation of the same function. Example input: ast.parse("def f(a, *, b): ...") Example output: "(a, b=b)" """ call_args = [arg.arg for arg in funcdef.args.args] if funcdef.args.vararg: call_args.append("*" + funcdef.args.vararg.arg) for arg in funcdef.args.kwonlyargs: call_args.append(arg.arg + "=" + arg.arg) # noqa: PERF401 # clarity if funcdef.args.kwarg: call_args.append("**" + funcdef.args.kwarg.arg) return "({})".format(", ".join(call_args)) def run_black(file: File, source: str) -> tuple[bool, str]: """Run black on the specified file. Returns: Tuple of success and result string. ex.: (False, "Failed to run black!\nerror: cannot format ...") (True, "") Raises: ImportError: If black is not installed. """ # imported to check that `subprocess` calls will succeed import black # noqa: F401 # Black has an undocumented API, but it doesn't easily allow reading configuration from # pyproject.toml, and simultaneously pass in / receive the code as a string. # https://github.com/psf/black/issues/779 result = subprocess.run( # "-" as a filename = use stdin, return on stdout. [sys.executable, "-m", "black", "--stdin-filename", file.path, "-"], input=source, capture_output=True, encoding="utf8", ) if result.returncode != 0: return False, f"Failed to run black!\n{result.stderr}" return True, result.stdout def run_ruff(file: File, source: str) -> tuple[bool, str]: """Run ruff on the specified file. Returns: Tuple of success and result string. ex.: (False, "Failed to run ruff!\nerror: Failed to parse ...") (True, "") Raises: ImportError: If ruff is not installed. """ # imported to check that `subprocess` calls will succeed import ruff # noqa: F401 result = subprocess.run( # "-" as a filename = use stdin, return on stdout. [ sys.executable, "-m", "ruff", "check", "--fix", "--unsafe-fixes", "--stdin-filename", file.path, "-", ], input=source, capture_output=True, encoding="utf8", ) if result.returncode != 0: return False, f"Failed to run ruff!\n{result.stderr}" return True, result.stdout def run_linters(file: File, source: str) -> str: """Format the specified file using black and ruff. Returns: Formatted source code. Raises: ImportError: If either is not installed. SystemExit: If either failed. """ success, response = run_black(file, source) if not success: print(response) sys.exit(1) success, response = run_ruff(file, response) if not success: # pragma: no cover # Test for run_ruff should catch print(response) sys.exit(1) success, response = run_black(file, response) if not success: print(response) sys.exit(1) return response def gen_public_wrappers_source(file: File) -> str: """Scan the given .py file for @_public decorators, and generate wrapper functions. """ header = [HEADER] if file.imports: header.append(file.imports) if file.platform: # Simple checks to avoid repeating imports. If this messes up, type checkers/tests will # just give errors. if "TYPE_CHECKING" not in file.imports: header.append("from typing import TYPE_CHECKING\n") if "import sys" not in file.imports: # pragma: no cover header.append("import sys\n") header.append( f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n', ) generated = ["".join(header)] source = astor.code_to_ast.parse_file(file.path) method_names = [] for method in get_public_methods(source): # Remove self from arguments assert method.args.args[0].arg == "self" del method.args.args[0] method_names.append(method.name) for dec in method.decorator_list: # pragma: no cover if isinstance(dec, ast.Name) and dec.id == "contextmanager": is_cm = True break else: is_cm = False # Remove decorators method.decorator_list = [] # Create pass through arguments new_args = create_passthrough_args(method) # Remove method body without the docstring if ast.get_docstring(method) is None: del method.body[:] else: # The first entry is always the docstring del method.body[1:] # Create the function definition including the body func = astor.to_source(method, indent_with=" " * 4) if is_cm: # pragma: no cover func = func.replace("->Iterator", "->ContextManager") # Create export function body template = TEMPLATE.format( " await " if isinstance(method, ast.AsyncFunctionDef) else " ", file.modname, method.name + new_args, ) # Assemble function definition arguments and body snippet = func + indent(template, " " * 4) # Append the snippet to the corresponding module generated.append(snippet) method_names.sort() # Insert after the header, before function definitions generated.insert(1, f"__all__ = {method_names!r}") return "\n\n".join(generated) def matches_disk_files(new_files: dict[str, str]) -> bool: for new_path, new_source in new_files.items(): if not os.path.exists(new_path): return False with open(new_path, encoding="utf-8") as old_file: old_source = old_file.read() if old_source != new_source: return False return True def process(files: Iterable[File], *, do_test: bool) -> None: new_files = {} for file in files: print("Scanning:", file.path) new_source = gen_public_wrappers_source(file) new_source = run_linters(file, new_source) dirname, basename = os.path.split(file.path) new_path = os.path.join(dirname, PREFIX + basename) new_files[new_path] = new_source matches_disk = matches_disk_files(new_files) if do_test: if not matches_disk: print("Generated sources are outdated. Please regenerate.") sys.exit(1) else: print("Generated sources are up to date.") else: for new_path, new_source in new_files.items(): with open(new_path, "w", encoding="utf-8", newline="\n") as f: f.write(new_source) print("Regenerated sources successfully.") if not matches_disk: # With pre-commit integration, show that we edited files. sys.exit(1) # This is in fact run in CI, but only in the formatting check job, which # doesn't collect coverage. def main() -> None: # pragma: no cover parser = argparse.ArgumentParser( description="Generate python code for public api wrappers", ) parser.add_argument( "--test", "-t", action="store_true", help="test if code is still up to date", ) parsed_args = parser.parse_args() source_root = Path.cwd() # Double-check we found the right directory assert (source_root / "LICENSE").exists() core = source_root / "src/trio/_core" to_wrap = [ File(core / "_run.py", "runner", imports=IMPORTS_RUN), File( core / "_instrumentation.py", "runner.instruments", imports=IMPORTS_INSTRUMENT, ), File( core / "_io_windows.py", "runner.io_manager", platform="win32", imports=IMPORTS_WINDOWS, ), File( core / "_io_epoll.py", "runner.io_manager", platform="linux", imports=IMPORTS_EPOLL, ), File( core / "_io_kqueue.py", "runner.io_manager", platform="darwin", imports=IMPORTS_KQUEUE, ), ] process(to_wrap, do_test=parsed_args.test) IMPORTS_RUN = """\ from collections.abc import Awaitable, Callable from typing import Any, TYPE_CHECKING from outcome import Outcome import contextvars from ._run import _NO_SEND, RunStatistics, Task from ._entry_queue import TrioToken from .._abc import Clock if TYPE_CHECKING: from typing_extensions import Unpack from ._run import PosArgT """ IMPORTS_INSTRUMENT = """\ from ._instrumentation import Instrument """ IMPORTS_EPOLL = """\ from typing import TYPE_CHECKING if TYPE_CHECKING: from .._file_io import _HasFileNo """ IMPORTS_KQUEUE = """\ from typing import Callable, ContextManager, TYPE_CHECKING if TYPE_CHECKING: import select from .. import _core from ._traps import Abort, RaiseCancelT from .._file_io import _HasFileNo """ IMPORTS_WINDOWS = """\ from typing import TYPE_CHECKING, ContextManager if TYPE_CHECKING: from .._file_io import _HasFileNo from ._windows_cffi import Handle, CData from typing_extensions import Buffer from ._unbounded_queue import UnboundedQueue """ if __name__ == "__main__": # pragma: no cover main()