413 lines
12 KiB
Python
413 lines
12 KiB
Python
#! /usr/bin/env python3
|
|
"""
|
|
Code generation script for class methods
|
|
to be exported as public API
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import ast
|
|
import os
|
|
import subprocess
|
|
import sys
|
|
from pathlib import Path
|
|
from textwrap import indent
|
|
from typing import TYPE_CHECKING
|
|
|
|
import attrs
|
|
|
|
if TYPE_CHECKING:
|
|
from collections.abc import Iterable, Iterator
|
|
|
|
from typing_extensions import TypeGuard
|
|
|
|
# keep these imports up to date with conditional imports in test_gen_exports
|
|
# isort: split
|
|
import astor
|
|
|
|
# Generated modules are written next to their source module, with this prefix
# prepended to the file name (e.g. _run.py -> _generated_run.py).
PREFIX = "_generated"

# Opening of every generated module. It brings into scope the names that
# TEMPLATE's wrapper bodies use: ``sys``, ``LOCALS_KEY_KI_PROTECTION_ENABLED``
# and ``GLOBAL_RUN_CONTEXT``.
HEADER = """# ***********************************************************
# ******* WARNING: AUTOGENERATED! ALL EDITS WILL BE LOST ******
# *************************************************************
from __future__ import annotations

import sys

from ._ki import LOCALS_KEY_KI_PROTECTION_ENABLED
from ._run import GLOBAL_RUN_CONTEXT
"""
|
|
|
|
# Body of every generated wrapper; it is indented and appended under the
# signature copied from the @_public method. Placeholders, in order:
#   1. " await " for async methods, " " for sync ones,
#   2. File.modname (attribute path under GLOBAL_RUN_CONTEXT),
#   3. the call expression "<method name>(<passthrough args>)".
# The try covers the whole return expression, so an AttributeError raised
# inside the forwarded method is also reported as "must be called from async
# context".
TEMPLATE = """sys._getframe().f_locals[LOCALS_KEY_KI_PROTECTION_ENABLED] = True
try:
    return{}GLOBAL_RUN_CONTEXT.{}.{}
except AttributeError:
    raise RuntimeError("must be called from async context") from None
"""
|
|
|
|
|
|
@attrs.define
class File:
    """One source module whose ``@_public`` methods get exported as wrappers.

    Attributes:
        path: Module to scan. The generated module is written into the same
            directory, named ``PREFIX`` + the original file name.
        modname: Dotted attribute path under ``GLOBAL_RUN_CONTEXT`` that the
            generated wrappers forward to (e.g. ``"runner.io_manager"``).
        platform: When non-empty, the generated module asserts (under type
            checking) that ``sys.platform`` equals this value.
        imports: Extra import text placed right after the standard header.
    """

    path: Path
    modname: str
    # Keyword-only, optional settings.
    platform: str = attrs.field(kw_only=True, default="")
    imports: str = attrs.field(kw_only=True, default="")
|
|
|
|
|
|
def is_function(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Tell whether *node* is a function definition.

    Both plain ``def`` (``ast.FunctionDef``) and ``async def``
    (``ast.AsyncFunctionDef``) nodes qualify; every other node does not.
    """
    function_node_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    return isinstance(node, function_node_types)
|
|
|
|
|
|
def is_public(node: ast.AST) -> TypeGuard[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Tell whether *node* is a function decorated with a bare ``@_public``.

    Only a plain name decorator matches; attribute forms such as
    ``@mod._public`` or called forms such as ``@_public()`` do not.
    """
    if is_function(node):
        return any(
            isinstance(decorator, ast.Name) and decorator.id == "_public"
            for decorator in node.decorator_list
        )
    return False
|
|
|
|
|
|
def get_public_methods(
    tree: ast.AST,
) -> Iterator[ast.FunctionDef | ast.AsyncFunctionDef]:
    """Lazily yield every function in *tree* that is marked ``@_public``.

    The whole tree is traversed with ``ast.walk``, so methods nested inside
    classes are found as well as top-level functions.
    """
    yield from (node for node in ast.walk(tree) if is_public(node))
|
|
|
|
|
|
def create_passthrough_args(funcdef: ast.FunctionDef | ast.AsyncFunctionDef) -> str:
    """Given a function definition, create a string that represents taking all
    the arguments from the function, and passing them through to another
    invocation of the same function.

    Positional-only and regular positional parameters are forwarded by
    position, ``*args`` is star-expanded, keyword-only parameters are forwarded
    as ``name=name`` and ``**kwargs`` is double-star-expanded.

    Example input: ast.parse("def f(a, /, b, *c, d, **e): ...").body[0]
    Example output: "(a, b, *c, d=d, **e)"
    """
    signature = funcdef.args
    # Positional-only parameters (before "/") live in a separate list from
    # regular positional ones; both must be forwarded, in declaration order.
    call_args = [arg.arg for arg in (*signature.posonlyargs, *signature.args)]
    if signature.vararg:
        call_args.append("*" + signature.vararg.arg)
    call_args.extend(f"{arg.arg}={arg.arg}" for arg in signature.kwonlyargs)
    if signature.kwarg:
        call_args.append("**" + signature.kwarg.arg)
    return "({})".format(", ".join(call_args))
|
|
|
|
|
|
def run_black(file: File, source: str) -> tuple[bool, str]:
    """Format *source* with black, treating it as the contents of ``file.path``.

    Black runs as a subprocess that reads *source* from stdin; ``file.path``
    is passed as ``--stdin-filename``.

    Returns:
        Tuple of success and result string.
        ex.:
            (False, "Failed to run black!\nerror: cannot format ...")
            (True, "<formatted source>")

    Raises:
        ImportError: If black is not installed.
    """
    # Importing black here makes a missing install fail with ImportError
    # before we try to invoke it through a subprocess.
    import black  # noqa: F401

    # Black has an undocumented API, but it doesn't easily allow reading configuration from
    # pyproject.toml, and simultaneously pass in / receive the code as a string.
    # https://github.com/psf/black/issues/779
    command: list[str | Path] = [
        sys.executable,
        "-m",
        "black",
        "--stdin-filename",
        file.path,
        "-",  # "-" as a filename = use stdin, return on stdout.
    ]
    completed = subprocess.run(
        command,
        input=source,
        capture_output=True,
        encoding="utf8",
    )
    if completed.returncode == 0:
        return True, completed.stdout
    return False, f"Failed to run black!\n{completed.stderr}"
|
|
|
|
|
|
def run_ruff(file: File, source: str) -> tuple[bool, str]:
    """Apply ``ruff check --fix --unsafe-fixes`` to *source*.

    Ruff runs as a subprocess that reads *source* from stdin; ``file.path`` is
    passed as ``--stdin-filename`` so ruff treats the input as that file.

    Returns:
        Tuple of success and result string.
        ex.:
            (False, "Failed to run ruff!\nerror: Failed to parse ...")
            (True, "<formatted source>")

    Raises:
        ImportError: If ruff is not installed.
    """
    # Importing ruff here makes a missing install fail with ImportError
    # before we try to invoke it through a subprocess.
    import ruff  # noqa: F401

    command: list[str | Path] = [
        sys.executable,
        "-m",
        "ruff",
        "check",
        "--fix",
        "--unsafe-fixes",
        "--stdin-filename",
        file.path,
        "-",  # "-" as a filename = use stdin, return on stdout.
    ]
    completed = subprocess.run(
        command,
        input=source,
        capture_output=True,
        encoding="utf8",
    )
    if completed.returncode == 0:
        return True, completed.stdout
    return False, f"Failed to run ruff!\n{completed.stderr}"
|
|
|
|
|
|
def run_linters(file: File, source: str) -> str:
    """Format the specified file using black and ruff.

    The source goes through black, then ruff's autofixes, then black again.
    At the first failing stage the tool's error text is printed and the
    process exits with status 1.

    Returns:
        Formatted source code.

    Raises:
        ImportError: If either is not installed.
        SystemExit: If either failed.
    """
    # The final black pass presumably re-normalizes whatever ruff's fixes
    # rewrote.
    for stage in (run_black, run_ruff, run_black):
        succeeded, source = stage(file, source)
        if not succeeded:
            print(source)
            sys.exit(1)
    return source
|
|
|
|
|
|
def gen_public_wrappers_source(file: File) -> str:
    """Scan the given .py file for @_public decorators, and generate wrapper
    functions.

    For every ``@_public`` method in ``file.path`` this emits a module-level
    function with the same signature (minus ``self``) and the same docstring.
    Its body is ``TEMPLATE``, which forwards the call to
    ``GLOBAL_RUN_CONTEXT.<file.modname>.<method name>``.

    The module starts with ``HEADER`` and ``file.imports``. Platform-specific
    files then get a ``sys.platform`` assertion. A sorted ``__all__`` comes
    next, followed by the wrappers.

    Returns:
        The generated source, not yet run through the formatters.
    """
    header = [HEADER]

    if file.imports:
        header.append(file.imports)
    if file.platform:
        # Simple checks to avoid repeating imports. HEADER is searched too,
        # since it already does `import sys` itself. If this messes up, type
        # checkers/tests will just give errors.
        already_imported = HEADER + file.imports
        if "TYPE_CHECKING" not in already_imported:
            header.append("from typing import TYPE_CHECKING\n")
        if "import sys" not in already_imported:  # pragma: no cover
            header.append("import sys\n")
        header.append(
            f'\nassert not TYPE_CHECKING or sys.platform=="{file.platform}"\n',
        )

    generated = ["".join(header)]

    source = astor.code_to_ast.parse_file(file.path)
    method_names = []
    for method in get_public_methods(source):
        # Remove self from arguments: the wrapper is a plain function.
        assert method.args.args[0].arg == "self"
        del method.args.args[0]
        method_names.append(method.name)

        # Must be checked before the decorators are stripped below.
        is_cm = any(
            isinstance(dec, ast.Name) and dec.id == "contextmanager"
            for dec in method.decorator_list
        )

        # Remove decorators
        method.decorator_list = []

        # Create pass through arguments
        new_args = create_passthrough_args(method)

        # Remove method body without the docstring
        if ast.get_docstring(method) is None:
            del method.body[:]
        else:
            # The first entry is always the docstring
            del method.body[1:]

        # Create the function definition including the body
        func = astor.to_source(method, indent_with=" " * 4)

        # A @contextmanager method's annotation (rendered by astor as
        # "->Iterator") describes the generator; the wrapper returns the
        # context manager instead.
        if is_cm:  # pragma: no cover
            func = func.replace("->Iterator", "->ContextManager")

        # Create export function body
        template = TEMPLATE.format(
            " await " if isinstance(method, ast.AsyncFunctionDef) else " ",
            file.modname,
            method.name + new_args,
        )

        # Assemble function definition arguments and body
        snippet = func + indent(template, " " * 4)

        # Append the snippet to the corresponding module
        generated.append(snippet)

    method_names.sort()
    # Insert after the header, before function definitions
    generated.insert(1, f"__all__ = {method_names!r}")
    return "\n\n".join(generated)
|
|
|
|
|
|
def matches_disk_files(new_files: dict[str, str]) -> bool:
    """Report whether every path in *new_files* already holds exactly the
    given text.

    A missing file counts as a mismatch. An empty mapping trivially matches.
    Files are read as UTF-8 text.
    """
    for candidate, expected in new_files.items():
        on_disk = Path(candidate)
        if not on_disk.exists():
            return False
        if on_disk.read_text(encoding="utf-8") != expected:
            return False
    return True
|
|
|
|
|
|
def process(files: Iterable[File], *, do_test: bool) -> None:
    """Generate and format the wrapper module for each of *files*.

    With ``do_test`` set, only compare against what is on disk: print the
    result and exit with status 1 if anything is outdated. Otherwise write
    every generated module (UTF-8, ``\\n`` newlines). Then exit with status 1
    if any of them changed.
    """
    new_files: dict[str, str] = {}
    for file in files:
        print("Scanning:", file.path)
        formatted = run_linters(file, gen_public_wrappers_source(file))
        dirname, basename = os.path.split(file.path)
        new_files[os.path.join(dirname, PREFIX + basename)] = formatted

    matches_disk = matches_disk_files(new_files)

    if do_test:
        if matches_disk:
            print("Generated sources are up to date.")
        else:
            print("Generated sources are outdated. Please regenerate.")
            sys.exit(1)
        return

    for target, contents in new_files.items():
        with open(target, "w", encoding="utf-8", newline="\n") as out:
            out.write(contents)
    print("Regenerated sources successfully.")
    if not matches_disk:
        # With pre-commit integration, show that we edited files.
        sys.exit(1)
|
|
|
|
|
|
# This is in fact run in CI, but only in the formatting check job, which
# doesn't collect coverage.
def main() -> None:  # pragma: no cover
    """Command-line entry point: regenerate (or, with --test, verify) the
    generated wrapper modules under ``src/trio/_core``.

    Must be run from the repository root.
    """
    parser = argparse.ArgumentParser(
        description="Generate python code for public api wrappers",
    )
    parser.add_argument(
        "--test",
        "-t",
        action="store_true",
        help="test if code is still up to date",
    )
    parsed_args = parser.parse_args()

    source_root = Path.cwd()
    # Double-check we found the right directory
    assert (source_root / "LICENSE").exists()
    core = source_root / "src/trio/_core"

    # The I/O managers all export into the same namespace, but each one
    # is specific to a single platform.
    io_managers = (
        ("_io_windows.py", "win32", IMPORTS_WINDOWS),
        ("_io_epoll.py", "linux", IMPORTS_EPOLL),
        ("_io_kqueue.py", "darwin", IMPORTS_KQUEUE),
    )
    to_wrap = [
        File(core / "_run.py", "runner", imports=IMPORTS_RUN),
        File(
            core / "_instrumentation.py",
            "runner.instruments",
            imports=IMPORTS_INSTRUMENT,
        ),
        *(
            File(
                core / filename,
                "runner.io_manager",
                platform=platform,
                imports=extra_imports,
            )
            for filename, platform, extra_imports in io_managers
        ),
    ]

    process(to_wrap, do_test=parsed_args.test)
|
|
|
|
|
|
# Per-module extra imports. Each is pasted verbatim after HEADER in its
# generated module (via File.imports). They presumably need to cover the
# names used in the copied signatures/annotations — verify when adding
# @_public methods.
IMPORTS_RUN = """\
from collections.abc import Awaitable, Callable
from typing import Any, TYPE_CHECKING

from outcome import Outcome
import contextvars

from ._run import _NO_SEND, RunStatistics, Task
from ._entry_queue import TrioToken
from .._abc import Clock

if TYPE_CHECKING:
    from typing_extensions import Unpack
    from ._run import PosArgT
"""
IMPORTS_INSTRUMENT = """\
from ._instrumentation import Instrument
"""

IMPORTS_EPOLL = """\
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from .._file_io import _HasFileNo
"""

IMPORTS_KQUEUE = """\
from typing import Callable, ContextManager, TYPE_CHECKING

if TYPE_CHECKING:
    import select

    from .. import _core
    from ._traps import Abort, RaiseCancelT
    from .._file_io import _HasFileNo
"""

IMPORTS_WINDOWS = """\
from typing import TYPE_CHECKING, ContextManager

if TYPE_CHECKING:
    from .._file_io import _HasFileNo
    from ._windows_cffi import Handle, CData
    from typing_extensions import Buffer

    from ._unbounded_queue import UnboundedQueue
"""
|
|
|
|
|
|
# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":  # pragma: no cover
    main()
|