Merge pull request #24262 from mnesarco/pyi-fixes-1
This commit is contained in:
@@ -2,7 +2,10 @@
|
||||
|
||||
"""Parses Python binding interface files into a typed AST model."""
|
||||
|
||||
import ast, re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
import ast
|
||||
import re
|
||||
from typing import List
|
||||
from model.typedModel import (
|
||||
GenerateModel,
|
||||
@@ -16,6 +19,221 @@ from model.typedModel import (
|
||||
SequenceProtocol,
|
||||
)
|
||||
|
||||
SIGNATURE_SEP = re.compile(r"\s+--\s+", re.DOTALL)
|
||||
SELF_CLS_ARG = re.compile(r"\(\s*(self|cls)(\s*,\s*)?")
|
||||
|
||||
|
||||
class ArgumentKind(Enum):
|
||||
PositionOnly = 0
|
||||
Arg = 1
|
||||
VarArg = 2
|
||||
KwOnly = 3
|
||||
KwArg = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class FuncArgument:
|
||||
name: str
|
||||
annotation: str
|
||||
kind: ArgumentKind
|
||||
|
||||
|
||||
class FunctionSignature:
|
||||
"""
|
||||
Parse function arguments with correct classification and order.
|
||||
"""
|
||||
|
||||
args: list[FuncArgument]
|
||||
has_keywords: bool
|
||||
docstring: str
|
||||
annotated_text: str
|
||||
text: str
|
||||
|
||||
const_flag: bool = False
|
||||
static_flag: bool = False
|
||||
class_flag: bool = False
|
||||
noargs_flag: bool = False
|
||||
is_overload: bool = False
|
||||
|
||||
def __init__(self, func: ast.FunctionDef):
|
||||
self.args = []
|
||||
self.has_keywords = False
|
||||
self.is_overload = False
|
||||
self.docstring = ast.get_docstring(func) or ""
|
||||
|
||||
args = func.args
|
||||
self.update_flags(func)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.PositionOnly,
|
||||
)
|
||||
for arg in args.posonlyargs
|
||||
),
|
||||
)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.Arg,
|
||||
)
|
||||
for arg in args.args
|
||||
),
|
||||
)
|
||||
|
||||
# tricky part to determine if there are keyword arguments or not
|
||||
if args.args:
|
||||
if args.args[0].arg in ("self", "cls"):
|
||||
instance_args = len(args.args) > 1
|
||||
else:
|
||||
instance_args = True
|
||||
else:
|
||||
instance_args = False
|
||||
|
||||
self.has_keywords = bool(instance_args or args.kwonlyargs or args.kwarg)
|
||||
|
||||
if args.vararg:
|
||||
self.args.append(
|
||||
FuncArgument(
|
||||
args.vararg.arg,
|
||||
self.get_annotation_str(args.vararg.annotation),
|
||||
ArgumentKind.VarArg,
|
||||
),
|
||||
)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.KwOnly,
|
||||
)
|
||||
for arg in args.kwonlyargs
|
||||
),
|
||||
)
|
||||
|
||||
if args.kwarg:
|
||||
self.args.append(
|
||||
FuncArgument(
|
||||
args.kwarg.arg,
|
||||
self.get_annotation_str(args.kwarg.annotation),
|
||||
ArgumentKind.KwArg,
|
||||
),
|
||||
)
|
||||
|
||||
# Annotated signatures (Not supported by __text_signature__)
|
||||
returns = ast.unparse(func.returns) if func.returns else "object"
|
||||
parameters = ast.unparse(func.args)
|
||||
self.annotated_text = SELF_CLS_ARG.sub("(", f"{func.name}({parameters}) -> {returns}", 1)
|
||||
|
||||
# Not Annotated signatures (supported by __text_signature__)
|
||||
all_args = [*args.posonlyargs, *args.args, args.vararg, *args.kwonlyargs, args.kwarg]
|
||||
for item in all_args:
|
||||
if item:
|
||||
item.annotation = None
|
||||
parameters = ast.unparse(args)
|
||||
self.text = SELF_CLS_ARG.sub(r"($\1\2", f"{func.name}({parameters})", 1)
|
||||
|
||||
def get_annotation_str(self, node: ast.AST | None) -> str:
|
||||
if not node:
|
||||
return "object"
|
||||
return ast.unparse(node)
|
||||
|
||||
def update_flags(self, func: ast.FunctionDef) -> None:
|
||||
for deco in func.decorator_list:
|
||||
match deco:
|
||||
case ast.Name(id, _):
|
||||
name = id
|
||||
case ast.Attribute(_, attr, _):
|
||||
name = attr
|
||||
case _:
|
||||
continue
|
||||
|
||||
match name:
|
||||
case "constmethod":
|
||||
self.const_flag = True
|
||||
case "classmethod":
|
||||
self.class_flag = True
|
||||
case "no_args":
|
||||
self.noargs_flag = True
|
||||
case "staticmethod":
|
||||
self.static_flag = True
|
||||
case "overload":
|
||||
self.is_overload = True
|
||||
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
signatures: list[FunctionSignature]
|
||||
|
||||
def __init__(self, func: ast.FunctionDef) -> None:
|
||||
self.name = func.name
|
||||
self.signatures = [FunctionSignature(func)]
|
||||
|
||||
def update(self, func: ast.FunctionDef) -> None:
|
||||
self.signatures.append(FunctionSignature(func))
|
||||
|
||||
@property
|
||||
def docstring(self) -> str:
|
||||
return "\n".join((f.docstring for f in self.signatures))
|
||||
|
||||
@property
|
||||
def has_keywords(self) -> bool:
|
||||
overloads = len(self.signatures) > 1
|
||||
if overloads:
|
||||
return any(sig.has_keywords for sig in self.signatures if sig.is_overload)
|
||||
return self.signatures[0].has_keywords
|
||||
|
||||
@property
|
||||
def signature(self) -> FunctionSignature | None:
|
||||
"""First non overload signature"""
|
||||
for sig in self.signatures:
|
||||
if not sig.is_overload:
|
||||
return sig
|
||||
return None
|
||||
|
||||
@property
|
||||
def static_flag(self) -> bool:
|
||||
return any(sig.static_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def const_flag(self) -> bool:
|
||||
return any(sig.const_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def class_flag(self) -> bool:
|
||||
return any(sig.class_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def noargs_flag(self) -> bool:
|
||||
return any(sig.noargs_flag for sig in self.signatures)
|
||||
|
||||
def add_signature_docs(self, doc: Documentation) -> None:
|
||||
if len(self.signatures) == 1:
|
||||
docstring = [self.signatures[0].text]
|
||||
signature = [self.signatures[0].annotated_text]
|
||||
else:
|
||||
docstring = [sig.text for sig in self.signatures if not sig.is_overload]
|
||||
signature = [sig.annotated_text for sig in self.signatures if sig.is_overload]
|
||||
|
||||
if not docstring:
|
||||
return
|
||||
|
||||
user_doc = doc.UserDocu or ""
|
||||
marker = SIGNATURE_SEP.search(user_doc)
|
||||
if marker:
|
||||
user_doc = user_doc[marker.end() :].strip()
|
||||
|
||||
docstring.append("--\n") # mark __text_signature__
|
||||
docstring.extend(signature) # Include real annotated signature in user docstring
|
||||
docstring.append(f"\n{user_doc}") # Rest of the docstring
|
||||
doc.UserDocu = "\n".join(docstring)
|
||||
|
||||
|
||||
def _extract_decorator_kwargs(decorator: ast.expr) -> dict:
|
||||
"""
|
||||
@@ -201,7 +419,12 @@ def _parse_class_attributes(class_node: ast.ClassDef, source_code: str) -> List[
|
||||
attr_doc = _parse_docstring_for_documentation(docstring)
|
||||
|
||||
param = Parameter(Name=name, Type=param_type)
|
||||
attr = Attribute(Documentation=attr_doc, Parameter=param, Name=name, ReadOnly=readonly)
|
||||
attr = Attribute(
|
||||
Documentation=attr_doc,
|
||||
Parameter=param,
|
||||
Name=name,
|
||||
ReadOnly=readonly,
|
||||
)
|
||||
attributes.append(attr)
|
||||
|
||||
return attributes
|
||||
@@ -216,7 +439,7 @@ def _parse_methods(class_node: ast.ClassDef) -> List[Methode]:
|
||||
"""
|
||||
methods = []
|
||||
|
||||
def collect_function_defs(nodes):
|
||||
def collect_function_defs(nodes) -> list[ast.FunctionDef]:
|
||||
funcs = []
|
||||
for node in nodes:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
@@ -226,99 +449,43 @@ def _parse_methods(class_node: ast.ClassDef) -> List[Methode]:
|
||||
funcs.extend(collect_function_defs(node.orelse))
|
||||
return funcs
|
||||
|
||||
for stmt in collect_function_defs(class_node.body):
|
||||
# Skip methods decorated with @overload
|
||||
skip_method = False
|
||||
for deco in stmt.decorator_list:
|
||||
match deco:
|
||||
case ast.Name(id="overload"):
|
||||
skip_method = True
|
||||
break
|
||||
case ast.Attribute(attr="overload"):
|
||||
skip_method = True
|
||||
break
|
||||
case _:
|
||||
pass
|
||||
if skip_method:
|
||||
continue
|
||||
# Collect including overloads
|
||||
functions: dict[str, Function] = {}
|
||||
for func_node in collect_function_defs(class_node.body):
|
||||
if func := functions.get(func_node.name):
|
||||
func.update(func_node)
|
||||
else:
|
||||
functions[func_node.name] = Function(func_node)
|
||||
|
||||
# Extract method name
|
||||
method_name = stmt.name
|
||||
|
||||
# Extract docstring
|
||||
method_docstring = ast.get_docstring(stmt) or ""
|
||||
doc_obj = _parse_docstring_for_documentation(method_docstring)
|
||||
has_keyword_args = False
|
||||
for func in functions.values():
|
||||
doc_obj = _parse_docstring_for_documentation(func.docstring)
|
||||
func.add_signature_docs(doc_obj)
|
||||
method_params = []
|
||||
|
||||
# Helper for extracting an annotation string
|
||||
def get_annotation_str(annotation):
|
||||
match annotation:
|
||||
case ast.Name(id=name):
|
||||
return name
|
||||
case ast.Attribute(value=ast.Name(id=name), attr=attr):
|
||||
return f"{name}.{attr}"
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=_):
|
||||
return name
|
||||
case ast.Subscript(
|
||||
value=ast.Attribute(value=ast.Name(id=name), attr=attr), slice=_
|
||||
):
|
||||
return f"{name}.{attr}"
|
||||
case _:
|
||||
return "object"
|
||||
signature = func.signature
|
||||
if signature is None:
|
||||
continue
|
||||
|
||||
# Process positional parameters (skipping self/cls)
|
||||
for arg in stmt.args.args:
|
||||
param_name = arg.arg
|
||||
if param_name in ("self", "cls"):
|
||||
for arg_i, arg in enumerate(signature.args):
|
||||
param_name = arg.name
|
||||
if arg_i == 0 and param_name in ("self", "cls"):
|
||||
continue
|
||||
annotation_str = "object"
|
||||
if arg.annotation:
|
||||
annotation_str = get_annotation_str(arg.annotation)
|
||||
param_type = _python_type_to_parameter_type(annotation_str)
|
||||
param_type = _python_type_to_parameter_type(arg.annotation)
|
||||
method_params.append(Parameter(Name=param_name, Type=param_type))
|
||||
|
||||
# Process keyword-only parameters
|
||||
for kwarg in stmt.args.kwonlyargs:
|
||||
has_keyword_args = True
|
||||
param_name = kwarg.arg
|
||||
annotation_str = "object"
|
||||
if kwarg.annotation:
|
||||
annotation_str = get_annotation_str(kwarg.annotation)
|
||||
param_type = _python_type_to_parameter_type(annotation_str)
|
||||
method_params.append(Parameter(Name=param_name, Type=param_type))
|
||||
|
||||
if stmt.args.kwarg:
|
||||
has_keyword_args = True
|
||||
|
||||
keyword_flag = has_keyword_args and not stmt.args.vararg
|
||||
|
||||
# Check for various decorators using any(...)
|
||||
const_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "constmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
static_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "staticmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
class_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "classmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
no_args = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "no_args" for deco in stmt.decorator_list
|
||||
)
|
||||
|
||||
methode = Methode(
|
||||
Name=method_name,
|
||||
method = Methode(
|
||||
Name=func.name,
|
||||
Documentation=doc_obj,
|
||||
Parameter=method_params,
|
||||
Const=const_method_flag,
|
||||
Static=static_method_flag,
|
||||
Class=class_method_flag,
|
||||
Keyword=keyword_flag,
|
||||
NoArgs=no_args,
|
||||
Const=func.const_flag,
|
||||
Static=func.static_flag,
|
||||
Class=func.class_flag,
|
||||
Keyword=func.has_keywords,
|
||||
NoArgs=func.noargs_flag,
|
||||
)
|
||||
|
||||
methods.append(methode)
|
||||
methods.append(method)
|
||||
|
||||
return methods
|
||||
|
||||
|
||||
Reference in New Issue
Block a user