[bindings] fix signatures in pyi files
This commit is contained in:
@@ -14,9 +14,174 @@ from model.typedModel import (
|
||||
Parameter,
|
||||
ParameterType,
|
||||
SequenceProtocol,
|
||||
FuncArgument,
|
||||
ArgumentKind,
|
||||
)
|
||||
|
||||
|
||||
class FunctionSignature:
|
||||
"""
|
||||
Parse function arguments with correct classification and order.
|
||||
"""
|
||||
|
||||
args: list[FuncArgument]
|
||||
has_keywords: bool
|
||||
docstring: str
|
||||
|
||||
const_flag: bool = False
|
||||
static_flag: bool = False
|
||||
class_flag: bool = False
|
||||
noargs_flag: bool = False
|
||||
is_overload: bool = False
|
||||
|
||||
def __init__(self, func: ast.FunctionDef):
|
||||
self.args = []
|
||||
self.has_keywords = False
|
||||
self.is_overload = False
|
||||
self.docstring = ast.get_docstring(func) or ""
|
||||
|
||||
args = func.args
|
||||
self.update_flags(func)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.PositionOnly,
|
||||
)
|
||||
for arg in args.posonlyargs
|
||||
),
|
||||
)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.Arg,
|
||||
)
|
||||
for arg in args.args
|
||||
),
|
||||
)
|
||||
|
||||
# tricky part to determine if there are keyword arguments or not
|
||||
if args.args:
|
||||
if args.args[0].arg in ("self", "cls"):
|
||||
instance_args = len(args.args) > 1
|
||||
else:
|
||||
instance_args = True
|
||||
else:
|
||||
instance_args = False
|
||||
|
||||
self.has_keywords = bool(instance_args or args.kwonlyargs or args.kwarg)
|
||||
|
||||
if args.vararg:
|
||||
self.args.append(
|
||||
FuncArgument(
|
||||
args.vararg.arg,
|
||||
self.get_annotation_str(args.vararg.annotation),
|
||||
ArgumentKind.VarArg,
|
||||
),
|
||||
)
|
||||
|
||||
self.args.extend(
|
||||
(
|
||||
FuncArgument(
|
||||
arg.arg,
|
||||
self.get_annotation_str(arg.annotation),
|
||||
ArgumentKind.KwOnly,
|
||||
)
|
||||
for arg in args.kwonlyargs
|
||||
),
|
||||
)
|
||||
|
||||
if args.kwarg:
|
||||
self.args.append(
|
||||
FuncArgument(
|
||||
args.kwarg.arg,
|
||||
self.get_annotation_str(args.kwarg.annotation),
|
||||
ArgumentKind.KwArg,
|
||||
),
|
||||
)
|
||||
|
||||
def get_annotation_str(self, node: ast.AST | None) -> str:
|
||||
if not node:
|
||||
return "object"
|
||||
return ast.unparse(node)
|
||||
|
||||
def update_flags(self, func: ast.FunctionDef) -> None:
|
||||
for deco in func.decorator_list:
|
||||
match deco:
|
||||
case ast.Name(id, _):
|
||||
name = id
|
||||
case ast.Attribute(_, attr, _):
|
||||
name = attr
|
||||
case _:
|
||||
continue
|
||||
|
||||
match name:
|
||||
case "constmethod":
|
||||
self.const_flag = True
|
||||
case "classmethod":
|
||||
self.class_flag = True
|
||||
case "no_args":
|
||||
self.noargs_flag = True
|
||||
case "staticmethod":
|
||||
self.static_flag = True
|
||||
case "overload":
|
||||
self.is_overload = True
|
||||
|
||||
|
||||
class Function:
|
||||
name: str
|
||||
signatures: list[FunctionSignature]
|
||||
|
||||
def __init__(self, func: ast.FunctionDef) -> None:
|
||||
self.name = func.name
|
||||
self.signatures = [FunctionSignature(func)]
|
||||
|
||||
def update(self, func: ast.FunctionDef) -> None:
|
||||
self.signatures.append(FunctionSignature(func))
|
||||
|
||||
@property
|
||||
def docstring(self) -> str:
|
||||
return "\n".join((f.docstring for f in self.signatures))
|
||||
|
||||
@property
|
||||
def has_keywords(self) -> bool:
|
||||
overloads = len(self.signatures) > 1
|
||||
if overloads:
|
||||
return any(
|
||||
sig.has_keywords for sig in self.signatures if sig.is_overload
|
||||
)
|
||||
return self.signatures[0].has_keywords
|
||||
|
||||
@property
|
||||
def signature(self) -> FunctionSignature | None:
|
||||
"""First non overload signature"""
|
||||
for sig in self.signatures:
|
||||
if not sig.is_overload:
|
||||
return sig
|
||||
return None
|
||||
|
||||
@property
|
||||
def static_flag(self) -> bool:
|
||||
return any(sig.static_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def const_flag(self) -> bool:
|
||||
return any(sig.const_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def class_flag(self) -> bool:
|
||||
return any(sig.class_flag for sig in self.signatures)
|
||||
|
||||
@property
|
||||
def noargs_flag(self) -> bool:
|
||||
return any(sig.noargs_flag for sig in self.signatures)
|
||||
|
||||
|
||||
def _extract_decorator_kwargs(decorator: ast.expr) -> dict:
|
||||
"""
|
||||
Extract keyword arguments from a decorator call like `@export(Father="...", Name="...")`.
|
||||
@@ -146,7 +311,9 @@ def _python_type_to_parameter_type(py_type: str) -> ParameterType:
|
||||
return ParameterType.OBJECT
|
||||
|
||||
|
||||
def _parse_class_attributes(class_node: ast.ClassDef, source_code: str) -> List[Attribute]:
|
||||
def _parse_class_attributes(
|
||||
class_node: ast.ClassDef, source_code: str
|
||||
) -> List[Attribute]:
|
||||
"""
|
||||
Parse top-level attributes (e.g. `TypeId: str = ""`) from the class AST node.
|
||||
We'll create an `Attribute` for each. For the `Documentation` of each attribute,
|
||||
@@ -158,7 +325,11 @@ def _parse_class_attributes(class_node: ast.ClassDef, source_code: str) -> List[
|
||||
for idx, stmt in enumerate(class_node.body):
|
||||
if isinstance(stmt, ast.AnnAssign):
|
||||
# e.g.: `TypeId: Final[str] = ""`
|
||||
name = stmt.target.id if isinstance(stmt.target, ast.Name) else "unknown"
|
||||
name = (
|
||||
stmt.target.id
|
||||
if isinstance(stmt.target, ast.Name)
|
||||
else "unknown"
|
||||
)
|
||||
# Evaluate the type annotation and detect Final for read-only attributes
|
||||
if isinstance(stmt.annotation, ast.Name):
|
||||
# e.g. `str`
|
||||
@@ -201,7 +372,12 @@ def _parse_class_attributes(class_node: ast.ClassDef, source_code: str) -> List[
|
||||
attr_doc = _parse_docstring_for_documentation(docstring)
|
||||
|
||||
param = Parameter(Name=name, Type=param_type)
|
||||
attr = Attribute(Documentation=attr_doc, Parameter=param, Name=name, ReadOnly=readonly)
|
||||
attr = Attribute(
|
||||
Documentation=attr_doc,
|
||||
Parameter=param,
|
||||
Name=name,
|
||||
ReadOnly=readonly,
|
||||
)
|
||||
attributes.append(attr)
|
||||
|
||||
return attributes
|
||||
@@ -216,7 +392,7 @@ def _parse_methods(class_node: ast.ClassDef) -> List[Methode]:
|
||||
"""
|
||||
methods = []
|
||||
|
||||
def collect_function_defs(nodes):
|
||||
def collect_function_defs(nodes) -> list[ast.FunctionDef]:
|
||||
funcs = []
|
||||
for node in nodes:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
@@ -226,99 +402,42 @@ def _parse_methods(class_node: ast.ClassDef) -> List[Methode]:
|
||||
funcs.extend(collect_function_defs(node.orelse))
|
||||
return funcs
|
||||
|
||||
for stmt in collect_function_defs(class_node.body):
|
||||
# Skip methods decorated with @overload
|
||||
skip_method = False
|
||||
for deco in stmt.decorator_list:
|
||||
match deco:
|
||||
case ast.Name(id="overload"):
|
||||
skip_method = True
|
||||
break
|
||||
case ast.Attribute(attr="overload"):
|
||||
skip_method = True
|
||||
break
|
||||
case _:
|
||||
pass
|
||||
if skip_method:
|
||||
continue
|
||||
# Collect including overloads
|
||||
functions: dict[str, Function] = {}
|
||||
for func_node in collect_function_defs(class_node.body):
|
||||
if func := functions.get(func_node.name):
|
||||
func.update(func_node)
|
||||
else:
|
||||
functions[func_node.name] = Function(func_node)
|
||||
|
||||
# Extract method name
|
||||
method_name = stmt.name
|
||||
|
||||
# Extract docstring
|
||||
method_docstring = ast.get_docstring(stmt) or ""
|
||||
doc_obj = _parse_docstring_for_documentation(method_docstring)
|
||||
has_keyword_args = False
|
||||
for func in functions.values():
|
||||
doc_obj = _parse_docstring_for_documentation(func.docstring)
|
||||
method_params = []
|
||||
|
||||
# Helper for extracting an annotation string
|
||||
def get_annotation_str(annotation):
|
||||
match annotation:
|
||||
case ast.Name(id=name):
|
||||
return name
|
||||
case ast.Attribute(value=ast.Name(id=name), attr=attr):
|
||||
return f"{name}.{attr}"
|
||||
case ast.Subscript(value=ast.Name(id=name), slice=_):
|
||||
return name
|
||||
case ast.Subscript(
|
||||
value=ast.Attribute(value=ast.Name(id=name), attr=attr), slice=_
|
||||
):
|
||||
return f"{name}.{attr}"
|
||||
case _:
|
||||
return "object"
|
||||
signature = func.signature
|
||||
if signature is None:
|
||||
continue
|
||||
|
||||
# Process positional parameters (skipping self/cls)
|
||||
for arg in stmt.args.args:
|
||||
param_name = arg.arg
|
||||
if param_name in ("self", "cls"):
|
||||
for arg_i, arg in enumerate(signature.args):
|
||||
param_name = arg.name
|
||||
if arg_i == 0 and param_name in ("self", "cls"):
|
||||
continue
|
||||
annotation_str = "object"
|
||||
if arg.annotation:
|
||||
annotation_str = get_annotation_str(arg.annotation)
|
||||
param_type = _python_type_to_parameter_type(annotation_str)
|
||||
param_type = _python_type_to_parameter_type(arg.annotation)
|
||||
method_params.append(Parameter(Name=param_name, Type=param_type))
|
||||
|
||||
# Process keyword-only parameters
|
||||
for kwarg in stmt.args.kwonlyargs:
|
||||
has_keyword_args = True
|
||||
param_name = kwarg.arg
|
||||
annotation_str = "object"
|
||||
if kwarg.annotation:
|
||||
annotation_str = get_annotation_str(kwarg.annotation)
|
||||
param_type = _python_type_to_parameter_type(annotation_str)
|
||||
method_params.append(Parameter(Name=param_name, Type=param_type))
|
||||
|
||||
if stmt.args.kwarg:
|
||||
has_keyword_args = True
|
||||
|
||||
keyword_flag = has_keyword_args and not stmt.args.vararg
|
||||
|
||||
# Check for various decorators using any(...)
|
||||
const_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "constmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
static_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "staticmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
class_method_flag = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "classmethod" for deco in stmt.decorator_list
|
||||
)
|
||||
no_args = any(
|
||||
isinstance(deco, ast.Name) and deco.id == "no_args" for deco in stmt.decorator_list
|
||||
)
|
||||
|
||||
methode = Methode(
|
||||
Name=method_name,
|
||||
method = Methode(
|
||||
Name=func.name,
|
||||
Documentation=doc_obj,
|
||||
Parameter=method_params,
|
||||
Const=const_method_flag,
|
||||
Static=static_method_flag,
|
||||
Class=class_method_flag,
|
||||
Keyword=keyword_flag,
|
||||
NoArgs=no_args,
|
||||
Const=func.const_flag,
|
||||
Static=func.static_flag,
|
||||
Class=func.class_flag,
|
||||
Keyword=func.has_keywords,
|
||||
NoArgs=func.noargs_flag,
|
||||
)
|
||||
|
||||
methods.append(methode)
|
||||
methods.append(method)
|
||||
|
||||
return methods
|
||||
|
||||
@@ -457,7 +576,9 @@ def _extract_base_class_name(base: ast.expr) -> str:
|
||||
return base_str
|
||||
|
||||
|
||||
def _parse_class(class_node, source_code: str, path: str, imports_mapping: dict) -> PythonExport:
|
||||
def _parse_class(
|
||||
class_node, source_code: str, path: str, imports_mapping: dict
|
||||
) -> PythonExport:
|
||||
base_class_name = None
|
||||
for base in class_node.bases:
|
||||
base_class_name = _extract_base_class_name(base)
|
||||
@@ -489,7 +610,9 @@ def _parse_class(class_node, source_code: str, path: str, imports_mapping: dict)
|
||||
match args[0]:
|
||||
case ast.Constant(value=val):
|
||||
class_declarations_text = val
|
||||
case ast.Call(func=ast.Name(id="sequence_protocol"), keywords=_, args=_):
|
||||
case ast.Call(
|
||||
func=ast.Name(id="sequence_protocol"), keywords=_, args=_
|
||||
):
|
||||
sequence_protocol_kwargs = _extract_decorator_kwargs(decorator)
|
||||
case _:
|
||||
pass
|
||||
@@ -509,23 +632,33 @@ def _parse_class(class_node, source_code: str, path: str, imports_mapping: dict)
|
||||
native_python_class_name = _get_native_python_class_name(class_node.name)
|
||||
include = _get_module_path(module_name) + "/" + native_class_name + ".h"
|
||||
|
||||
father_native_python_class_name = _get_native_python_class_name(base_class_name)
|
||||
father_native_python_class_name = _get_native_python_class_name(
|
||||
base_class_name
|
||||
)
|
||||
father_include = (
|
||||
_get_module_path(parent_module_name) + "/" + father_native_python_class_name + ".h"
|
||||
_get_module_path(parent_module_name)
|
||||
+ "/"
|
||||
+ father_native_python_class_name
|
||||
+ ".h"
|
||||
)
|
||||
|
||||
py_export = PythonExport(
|
||||
Documentation=doc_obj,
|
||||
ModuleName=module_name,
|
||||
Name=export_decorator_kwargs.get("Name", "") or native_python_class_name,
|
||||
Name=export_decorator_kwargs.get("Name", "")
|
||||
or native_python_class_name,
|
||||
PythonName=export_decorator_kwargs.get("PythonName", "") or None,
|
||||
Include=export_decorator_kwargs.get("Include", "") or include,
|
||||
Father=export_decorator_kwargs.get("Father", "") or father_native_python_class_name,
|
||||
Father=export_decorator_kwargs.get("Father", "")
|
||||
or father_native_python_class_name,
|
||||
Twin=export_decorator_kwargs.get("Twin", "") or native_class_name,
|
||||
TwinPointer=export_decorator_kwargs.get("TwinPointer", "") or native_class_name,
|
||||
TwinPointer=export_decorator_kwargs.get("TwinPointer", "")
|
||||
or native_class_name,
|
||||
Namespace=export_decorator_kwargs.get("Namespace", "") or module_name,
|
||||
FatherInclude=export_decorator_kwargs.get("FatherInclude", "") or father_include,
|
||||
FatherNamespace=export_decorator_kwargs.get("FatherNamespace", "") or parent_module_name,
|
||||
FatherInclude=export_decorator_kwargs.get("FatherInclude", "")
|
||||
or father_include,
|
||||
FatherNamespace=export_decorator_kwargs.get("FatherNamespace", "")
|
||||
or parent_module_name,
|
||||
Constructor=export_decorator_kwargs.get("Constructor", False),
|
||||
NumberProtocol=export_decorator_kwargs.get("NumberProtocol", False),
|
||||
RichCompare=export_decorator_kwargs.get("RichCompare", False),
|
||||
|
||||
@@ -326,3 +326,19 @@ class GenerateModel:
|
||||
# Each method might have parameters
|
||||
for param in meth.Parameter:
|
||||
print(f" * param: {param.Name}, type={param.Type}")
|
||||
|
||||
# Rich Modules
|
||||
|
||||
class ArgumentKind(Enum):
|
||||
PositionOnly = 0
|
||||
Arg = 1
|
||||
VarArg = 2
|
||||
KwOnly = 3
|
||||
KwArg = 4
|
||||
|
||||
@dataclass
|
||||
class FuncArgument:
|
||||
name: str
|
||||
annotation: str
|
||||
kind: ArgumentKind
|
||||
|
||||
|
||||
Reference in New Issue
Block a user