diff --git a/.gitea/workflows/ci.yaml b/.gitea/workflows/ci.yaml new file mode 100644 index 0000000..2dc42d0 --- /dev/null +++ b/.gitea/workflows/ci.yaml @@ -0,0 +1,65 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install ruff mypy + pip install -e ".[dev]" || pip install ruff mypy numpy + + - name: Ruff check + run: ruff check solver/ freecad/ tests/ scripts/ + + - name: Ruff format check + run: ruff format --check solver/ freecad/ tests/ scripts/ + + type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install mypy numpy + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install torch-geometric + pip install -e ".[dev]" + + - name: Mypy + run: mypy solver/ freecad/ + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install torch-geometric + pip install -e ".[train,dev]" + + - name: Run tests + run: pytest tests/ freecad/tests/ -v --tb=short diff --git a/.gitignore b/.gitignore index 2d5c3f3..26fcec5 100644 --- a/.gitignore +++ b/.gitignore @@ -1,44 +1,83 @@ -# Prerequisites +# C++ compiled objects *.d - -# Compiled Object files *.slo *.lo *.o *.obj - -# Precompiled Headers *.gch *.pch -# Compiled Dynamic libraries +# C++ libraries *.so *.dylib *.dll - -# Fortran module files -*.mod -*.smod - -# Compiled Static libraries *.lai *.la *.a *.lib -# Executables +# C++ executables *.exe *.out *.app -.vs +# C++ build +build/ +cmake-build-debug/ +.vs/ x64/ +temp/ + +# 
OndselSolver test artifacts *.bak assembly.asmt - -build -cmake-build-debug -.idea -temp/ /testapp/draggingBackhoe.log /testapp/runPreDragBackhoe.asmt + +# Python +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +*.egg + +# Virtual environments +.venv/ +venv/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# mypy / ruff / pytest +.mypy_cache/ +.ruff_cache/ +.pytest_cache/ + +# Data (large files tracked separately) +data/synthetic/*.pt +data/fusion360/*.json +data/fusion360/*.step +data/processed/*.pt +!data/**/.gitkeep + +# Model checkpoints +*.ckpt +*.pth +*.onnx +*.torchscript + +# Experiment tracking +wandb/ +runs/ + +# OS +.DS_Store +Thumbs.db + +# Environment +.env diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..69b8611 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.4 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + additional_dependencies: + - torch>=2.2 + - numpy>=1.26 + args: [--ignore-missing-imports] + + - repo: https://github.com/compilerla/conventional-pre-commit + rev: v3.1.0 + hooks: + - id: conventional-pre-commit + stages: [commit-msg] + args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ebbca0a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 + +# System deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 python3.11-venv python3.11-dev python3-pip \ + git wget curl \ + # FreeCAD headless deps + freecad \ + libgl1-mesa-glx libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 + +# Create 
venv +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install PyTorch with CUDA +RUN pip install --no-cache-dir \ + torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 + +# Install PyG +RUN pip install --no-cache-dir \ + torch-geometric \ + pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \ + -f https://data.pyg.org/whl/torch-2.4.0+cu124.html + +WORKDIR /workspace + +# Install project +COPY pyproject.toml . +RUN pip install --no-cache-dir -e ".[train,dev]" || true + +COPY . . +RUN pip install --no-cache-dir -e ".[train,dev]" + +# ------------------------------------------------------------------- +FROM base AS cpu + +# CPU-only variant (for CI and non-GPU environments) +FROM python:3.11-slim AS cpu-only + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git freecad libgl1-mesa-glx libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY pyproject.toml . +RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache-dir torch-geometric + +COPY . . +RUN pip install --no-cache-dir -e ".[train,dev]" + +CMD ["pytest", "tests/", "-v"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..18d1309 --- /dev/null +++ b/Makefile @@ -0,0 +1,48 @@ +.PHONY: train test lint data-gen export format type-check install dev clean help + +PYTHON ?= python +PYTEST ?= pytest +RUFF ?= ruff +MYPY ?= mypy + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \ + awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +install: ## Install core dependencies + pip install -e . 
+ +dev: ## Install all dependencies including dev tools + pip install -e ".[train,dev]" + pre-commit install + pre-commit install --hook-type commit-msg + +train: ## Run training (pass CONFIG=path/to/config.yaml) + $(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG)) + +test: ## Run test suite + $(PYTEST) tests/ freecad/tests/ -v --tb=short + +lint: ## Run ruff linter + $(RUFF) check solver/ freecad/ tests/ scripts/ + +format: ## Format code with ruff + $(RUFF) format solver/ freecad/ tests/ scripts/ + $(RUFF) check --fix solver/ freecad/ tests/ scripts/ + +type-check: ## Run mypy type checker + $(MYPY) solver/ freecad/ + +data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml) + $(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG)) + +export: ## Export trained model for deployment + $(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL)) + +clean: ## Remove build artifacts and caches + rm -rf build/ dist/ *.egg-info/ + rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/ + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + +check: lint type-check test ## Run all checks (lint, type-check, test) diff --git a/README.md b/README.md index 2ceda78..17e5095 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,91 @@ -# MbDCode -Assembly Constraints and Multibody Dynamics code +# Kindred Solver -Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited) +Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer. -The MbD theory is at -https://github.com/Ondsel-Development/MbDTheory +## Components + +### OndselSolver (C++) + +Numerical assembly constraint solver using multibody dynamics. 
Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver. + +- Source: `OndselSolver/` +- Entry point: `OndselSolverMain/` +- Tests: `tests/`, `testapp/` +- Build: CMake + +**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory) + +#### Building + +```bash +cmake -B build -DCMAKE_BUILD_TYPE=Release +cmake --build build +``` + +### ML Solver Layer (Python) + +Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns. + +- Core library: `solver/` +- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling) +- Model architectures: `solver/models/` (GIN, GAT, NNConv) +- Training: `solver/training/` +- Inference: `solver/inference/` +- FreeCAD integration: `freecad/` +- Configuration: `configs/` (Hydra) + +#### Setup + +```bash +pip install -e ".[train,dev]" +pre-commit install +``` + +#### Usage + +```bash +make help # show all targets +make dev # install all deps + pre-commit hooks +make test # run tests +make lint # run ruff linter +make check # lint + type-check + test +make data-gen # generate synthetic data +make train # run training +make export # export model +``` + +Docker is also supported: + +```bash +docker compose up train # GPU training +docker compose up test # run tests +docker compose up data-gen # generate synthetic data +``` + +## Repository structure + +``` +kindred-solver/ +├── OndselSolver/ # C++ numerical solver library +├── OndselSolverMain/ # C++ solver CLI entry point +├── tests/ # C++ unit tests + Python tests +├── testapp/ # C++ test application +├── solver/ # Python ML solver library +│ ├── datagen/ # Synthetic data generation (pebble game) +│ ├── datasets/ # PyG dataset adapters +│ ├── models/ # GNN architectures 
+│ ├── training/ # Training loops +│ ├── evaluation/ # Metrics and visualization +│ └── inference/ # Runtime prediction API +├── freecad/ # FreeCAD workbench integration +├── configs/ # Hydra configs (dataset, model, training, export) +├── scripts/ # CLI utilities +├── data/ # Datasets (not committed) +├── export/ # Model packaging +└── docs/ # Documentation +``` + +## License + +OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE)) +ML Solver Layer: Apache-2.0 diff --git a/configs/dataset/fusion360.yaml b/configs/dataset/fusion360.yaml new file mode 100644 index 0000000..68b7b21 --- /dev/null +++ b/configs/dataset/fusion360.yaml @@ -0,0 +1,12 @@ +# Fusion 360 Gallery dataset config +name: fusion360 +data_dir: data/fusion360 +output_dir: data/processed + +splits: + train: 0.8 + val: 0.1 + test: 0.1 + +stratify_by: complexity +seed: 42 diff --git a/configs/dataset/synthetic.yaml b/configs/dataset/synthetic.yaml new file mode 100644 index 0000000..5f8c117 --- /dev/null +++ b/configs/dataset/synthetic.yaml @@ -0,0 +1,26 @@ +# Synthetic dataset generation config +name: synthetic +num_assemblies: 100000 +output_dir: data/synthetic +shard_size: 1000 + +complexity_distribution: + simple: 0.4 # 2-5 bodies + medium: 0.4 # 6-15 bodies + complex: 0.2 # 16-50 bodies + +body_count: + min: 2 + max: 50 + +templates: + - chain + - tree + - loop + - star + - mixed + +grounded_ratio: 0.5 +seed: 42 +num_workers: 4 +checkpoint_every: 5 diff --git a/configs/export/production.yaml b/configs/export/production.yaml new file mode 100644 index 0000000..1be79ec --- /dev/null +++ b/configs/export/production.yaml @@ -0,0 +1,25 @@ +# Production model export config +model_checkpoint: checkpoints/finetune/best_val_loss.ckpt +output_dir: export/ + +formats: + onnx: + enabled: true + opset_version: 17 + dynamic_axes: true + torchscript: + enabled: true + +model_card: + version: "0.1.0" + architecture: baseline + training_data: + - synthetic_100k + - fusion360_gallery + +size_budget_mb: 50 + 
+inference: + device: cpu + batch_size: 1 + confidence_threshold: 0.8 diff --git a/configs/model/baseline.yaml b/configs/model/baseline.yaml new file mode 100644 index 0000000..ebefe9b --- /dev/null +++ b/configs/model/baseline.yaml @@ -0,0 +1,24 @@ +# Baseline GIN model config +name: baseline +architecture: gin + +encoder: + num_layers: 3 + hidden_dim: 128 + dropout: 0.1 + +node_features_dim: 22 +edge_features_dim: 22 + +heads: + edge_classification: + enabled: true + hidden_dim: 64 + graph_classification: + enabled: true + num_classes: 4 # rigid, under, over, mixed + joint_type: + enabled: true + num_classes: 12 + dof_regression: + enabled: true diff --git a/configs/model/gat.yaml b/configs/model/gat.yaml new file mode 100644 index 0000000..6592bc1 --- /dev/null +++ b/configs/model/gat.yaml @@ -0,0 +1,28 @@ +# Advanced GAT model config +name: gat_solver +architecture: gat + +encoder: + num_layers: 4 + hidden_dim: 256 + num_heads: 8 + dropout: 0.1 + residual: true + +node_features_dim: 22 +edge_features_dim: 22 + +heads: + edge_classification: + enabled: true + hidden_dim: 128 + graph_classification: + enabled: true + num_classes: 4 + joint_type: + enabled: true + num_classes: 12 + dof_regression: + enabled: true + dof_tracking: + enabled: true diff --git a/configs/training/finetune.yaml b/configs/training/finetune.yaml new file mode 100644 index 0000000..09e8bae --- /dev/null +++ b/configs/training/finetune.yaml @@ -0,0 +1,45 @@ +# Fine-tuning on real data config +phase: finetune + +dataset: fusion360 +model: baseline + +pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt + +optimizer: + name: adamw + lr: 1e-5 + weight_decay: 1e-4 + +scheduler: + name: cosine_annealing + T_max: 50 + eta_min: 1e-7 + +training: + epochs: 50 + batch_size: 32 + gradient_clip: 1.0 + early_stopping_patience: 10 + amp: true + freeze_encoder: false # set true for frozen encoder experiment + +loss: + edge_weight: 1.0 + graph_weight: 0.5 + joint_type_weight: 0.3 + dof_weight: 
0.2 + redundant_penalty: 2.0 + +checkpointing: + save_best_val_loss: true + save_best_val_accuracy: true + save_every_n_epochs: 5 + checkpoint_dir: checkpoints/finetune + +logging: + backend: wandb + project: kindred-solver + log_every_n_steps: 20 + +seed: 42 diff --git a/configs/training/pretrain.yaml b/configs/training/pretrain.yaml new file mode 100644 index 0000000..90cd232 --- /dev/null +++ b/configs/training/pretrain.yaml @@ -0,0 +1,42 @@ +# Synthetic pre-training config +phase: pretrain + +dataset: synthetic +model: baseline + +optimizer: + name: adamw + lr: 1e-3 + weight_decay: 1e-4 + +scheduler: + name: cosine_annealing + T_max: 100 + eta_min: 1e-6 + +training: + epochs: 100 + batch_size: 64 + gradient_clip: 1.0 + early_stopping_patience: 10 + amp: true + +loss: + edge_weight: 1.0 + graph_weight: 0.5 + joint_type_weight: 0.3 + dof_weight: 0.2 + redundant_penalty: 2.0 # safety loss multiplier + +checkpointing: + save_best_val_loss: true + save_best_val_accuracy: true + save_every_n_epochs: 10 + checkpoint_dir: checkpoints/pretrain + +logging: + backend: wandb # or tensorboard + project: kindred-solver + log_every_n_steps: 50 + +seed: 42 diff --git a/data/fusion360/.gitkeep b/data/fusion360/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/processed/.gitkeep b/data/processed/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/splits/.gitkeep b/data/splits/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/synthetic/.gitkeep b/data/synthetic/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..dfe8b4c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,39 @@ +services: + train: + build: + context: . 
+ dockerfile: Dockerfile + target: base + volumes: + - .:/workspace + - ./data:/workspace/data + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + command: make train + environment: + - CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} + - WANDB_API_KEY=${WANDB_API_KEY:-} + + test: + build: + context: . + dockerfile: Dockerfile + target: cpu-only + volumes: + - .:/workspace + command: make check + + data-gen: + build: + context: . + dockerfile: Dockerfile + target: base + volumes: + - .:/workspace + - ./data:/workspace/data + command: make data-gen diff --git a/docs/.gitkeep b/docs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/export/.gitkeep b/export/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/freecad/__init__.py b/freecad/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/bridge/__init__.py b/freecad/bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/tests/__init__.py b/freecad/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/workbench/__init__.py b/freecad/workbench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..439159d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,97 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "kindred-solver" +version = "0.1.0" +description = "Assembly constraint prediction via GNN for Kindred Create" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.11" +authors = [ + { name = "Kindred Systems" }, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", +] +dependencies = [ + 
"torch>=2.2", + "torch-geometric>=2.5", + "numpy>=1.26", + "scipy>=1.12", +] + +[project.optional-dependencies] +train = [ + "wandb>=0.16", + "tensorboard>=2.16", + "hydra-core>=1.3", + "omegaconf>=2.3", + "matplotlib>=3.8", + "networkx>=3.2", +] +freecad = [ + "pyside6>=6.6", +] +dev = [ + "pytest>=8.0", + "pytest-cov>=4.1", + "ruff>=0.3", + "mypy>=1.8", + "pre-commit>=3.6", +] + +[project.urls] +Repository = "https://git.kindred-systems.com/kindred/solver" + +[tool.hatch.build.targets.wheel] +packages = ["solver", "freecad"] + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "TCH", # flake8-type-checking + "RUF", # ruff-specific +] + +[tool.ruff.lint.isort] +known-first-party = ["solver", "freecad"] + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "torch.*", + "torch_geometric.*", + "scipy.*", + "wandb.*", + "hydra.*", + "omegaconf.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests", "freecad/tests"] +addopts = "-v --tb=short" diff --git a/scripts/generate_synthetic.py b/scripts/generate_synthetic.py new file mode 100644 index 0000000..9a7f99f --- /dev/null +++ b/scripts/generate_synthetic.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Generate synthetic assembly dataset for kindred-solver training. + +Usage (argparse fallback — always available):: + + python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4 + +Usage (Hydra — when hydra-core is installed):: + + python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4 +""" + +from __future__ import annotations + + +def _try_hydra_main() -> bool: + """Attempt to run via Hydra. 
Returns *True* if Hydra handled it.""" + try: + import hydra # type: ignore[import-untyped] + from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped] + except ImportError: + return False + + @hydra.main( + config_path="../configs/dataset", + config_name="synthetic", + version_base=None, + ) + def _run(cfg: DictConfig) -> None: # type: ignore[type-arg] + from solver.datagen.dataset import DatasetConfig, DatasetGenerator + + config_dict = OmegaConf.to_container(cfg, resolve=True) + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + _run() # type: ignore[no-untyped-call] + return True + + +def _argparse_main() -> None: + """Fallback CLI using argparse.""" + import argparse + + parser = argparse.ArgumentParser( + description="Generate synthetic assembly dataset", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to YAML config file (optional)", + ) + parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory") + parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard") + parser.add_argument("--body-count-min", type=int, default=None, help="Min body count") + parser.add_argument("--body-count-max", type=int, default=None, help="Max body count") + parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio") + parser.add_argument("--seed", type=int, default=None, help="Random seed") + parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers") + parser.add_argument( + "--checkpoint-every", + type=int, + default=None, + help="Checkpoint interval (shards)", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Do not resume from existing checkpoints", + ) + + args = parser.parse_args() + + from solver.datagen.dataset import ( + DatasetConfig, + 
DatasetGenerator, + parse_simple_yaml, + ) + + config_dict: dict[str, object] = {} + if args.config: + config_dict = parse_simple_yaml(args.config) # type: ignore[assignment] + + # CLI args override config file (only when explicitly provided) + _override_map = { + "num_assemblies": args.num_assemblies, + "output_dir": args.output_dir, + "shard_size": args.shard_size, + "body_count_min": args.body_count_min, + "body_count_max": args.body_count_max, + "grounded_ratio": args.grounded_ratio, + "seed": args.seed, + "num_workers": args.num_workers, + "checkpoint_every": args.checkpoint_every, + } + for key, val in _override_map.items(): + if val is not None: + config_dict[key] = val + + if args.no_resume: + config_dict["resume"] = False + + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + +def main() -> None: + """Entry point: try Hydra first, fall back to argparse.""" + if not _try_hydra_main(): + _argparse_main() + + +if __name__ == "__main__": + main() diff --git a/solver/__init__.py b/solver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py new file mode 100644 index 0000000..7e26e21 --- /dev/null +++ b/solver/datagen/__init__.py @@ -0,0 +1,37 @@ +"""Data generation utilities for assembly constraint training data.""" + +from solver.datagen.analysis import analyze_assembly +from solver.datagen.dataset import DatasetConfig, DatasetGenerator +from solver.datagen.generator import ( + COMPLEXITY_RANGES, + AxisStrategy, + SyntheticAssemblyGenerator, +) +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.labeling import AssemblyLabels, label_assembly +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + PebbleState, + RigidBody, +) + +__all__ = [ + "COMPLEXITY_RANGES", + "AssemblyLabels", + "AxisStrategy", + "ConstraintAnalysis", + 
"DatasetConfig", + "DatasetGenerator", + "JacobianVerifier", + "Joint", + "JointType", + "PebbleGame3D", + "PebbleState", + "RigidBody", + "SyntheticAssemblyGenerator", + "analyze_assembly", + "label_assembly", +] diff --git a/solver/datagen/analysis.py b/solver/datagen/analysis.py new file mode 100644 index 0000000..6ad8491 --- /dev/null +++ b/solver/datagen/analysis.py @@ -0,0 +1,140 @@ +"""Combined pebble game + Jacobian verification analysis. + +Provides :func:`analyze_assembly`, the main entry point for full rigidity +analysis of an assembly using both combinatorial and numerical methods. +""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +__all__ = ["analyze_assembly"] + +_GROUND_ID = -1 + + +def analyze_assembly( + bodies: list[RigidBody], + joints: list[Joint], + ground_body: int | None = None, +) -> ConstraintAnalysis: + """Full rigidity analysis of an assembly using both methods. + + Args: + bodies: List of rigid bodies in the assembly. + joints: List of joints connecting bodies. + ground_body: If set, this body is fixed (adds 6 implicit constraints). + + Returns: + ConstraintAnalysis with combinatorial and numerical results. + """ + # --- Pebble Game --- + pg = PebbleGame3D() + all_edge_results = [] + + # Add a virtual ground body (id=-1) if grounding is requested. + # Grounding body X means adding a fixed joint between X and + # the virtual ground. This properly lets the pebble game account + # for the 6 removed DOF without breaking invariants. 
+ if ground_body is not None: + pg.add_body(_GROUND_ID) + + for body in bodies: + pg.add_body(body.body_id) + + if ground_body is not None: + ground_joint = Joint( + joint_id=-1, + body_a=ground_body, + body_b=_GROUND_ID, + joint_type=JointType.FIXED, + anchor_a=bodies[0].position if bodies else np.zeros(3), + anchor_b=bodies[0].position if bodies else np.zeros(3), + ) + pg.add_joint(ground_joint) + # Don't include ground joint edges in the output labels + # (they're infrastructure, not user constraints) + + for joint in joints: + results = pg.add_joint(joint) + all_edge_results.extend(results) + + combinatorial_independent = len(pg.state.independent_edges) + grounded = ground_body is not None + + # The virtual ground body contributes 6 pebbles to the total. + # Subtract those from the reported DOF for user-facing numbers. + raw_dof = pg.get_dof() + ground_offset = 6 if grounded else 0 + effective_dof = raw_dof - ground_offset + effective_internal_dof = max(0, effective_dof - (0 if grounded else 6)) + + # Classify based on effective (adjusted) DOF, not raw pebble game output, + # because the virtual ground body skews the raw numbers. 
+ redundant = pg.get_redundant_count() + if redundant > 0 and effective_internal_dof > 0: + combinatorial_classification = "mixed" + elif redundant > 0: + combinatorial_classification = "overconstrained" + elif effective_internal_dof > 0: + combinatorial_classification = "underconstrained" + else: + combinatorial_classification = "well-constrained" + + # --- Jacobian Verification --- + verifier = JacobianVerifier(bodies) + + for joint in joints: + verifier.add_joint_constraints(joint) + + # If grounded, remove the ground body's columns (fix its DOF) + j = verifier.get_jacobian() + if ground_body is not None and j.size > 0: + idx = verifier.body_index[ground_body] + cols_to_remove = list(range(idx * 6, (idx + 1) * 6)) + j = np.delete(j, cols_to_remove, axis=1) + + if j.size > 0: + sv = np.linalg.svd(j, compute_uv=False) + jacobian_rank = int(np.sum(sv > 1e-8)) + else: + jacobian_rank = 0 + + n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies) + jacobian_nullity = n_cols - jacobian_rank + + dependent = verifier.find_dependencies() + + # Adjust for ground + trivial_dof = 0 if ground_body is not None else 6 + jacobian_internal_dof = jacobian_nullity - trivial_dof + + geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank) + + # Rigidity: numerically rigid if nullity == trivial DOF + is_rigid = jacobian_nullity <= trivial_dof + is_minimally_rigid = is_rigid and len(dependent) == 0 + + return ConstraintAnalysis( + combinatorial_dof=effective_dof, + combinatorial_internal_dof=effective_internal_dof, + combinatorial_redundant=pg.get_redundant_count(), + combinatorial_classification=combinatorial_classification, + per_edge_results=all_edge_results, + jacobian_rank=jacobian_rank, + jacobian_nullity=jacobian_nullity, + jacobian_internal_dof=max(0, jacobian_internal_dof), + numerically_dependent=dependent, + geometric_degeneracies=geometric_degeneracies, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + ) diff --git 
a/solver/datagen/dataset.py b/solver/datagen/dataset.py new file mode 100644 index 0000000..e5a4ce2 --- /dev/null +++ b/solver/datagen/dataset.py @@ -0,0 +1,624 @@ +"""Dataset generation orchestrator with sharding, checkpointing, and statistics. + +Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator` +for parallel generation of synthetic assembly training data. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import math +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "DatasetConfig", + "DatasetGenerator", + "parse_simple_yaml", +] + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class DatasetConfig: + """Configuration for synthetic dataset generation.""" + + name: str = "synthetic" + num_assemblies: int = 100_000 + output_dir: str = "data/synthetic" + shard_size: int = 1000 + complexity_distribution: dict[str, float] = field( + default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2} + ) + body_count_min: int = 2 + body_count_max: int = 50 + templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"]) + grounded_ratio: float = 0.5 + seed: int = 42 + num_workers: int = 4 + checkpoint_every: int = 5 + resume: bool = True + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> DatasetConfig: + """Construct from a parsed config dict (e.g. YAML or OmegaConf). + + Handles both flat keys (``body_count_min``) and nested forms + (``body_count: {min: 2, max: 50}``). 
+ """ + kw: dict[str, Any] = {} + for key in ( + "name", + "num_assemblies", + "output_dir", + "shard_size", + "grounded_ratio", + "seed", + "num_workers", + "checkpoint_every", + "resume", + ): + if key in d: + kw[key] = d[key] + + # Handle nested body_count dict + if "body_count" in d and isinstance(d["body_count"], dict): + bc = d["body_count"] + if "min" in bc: + kw["body_count_min"] = int(bc["min"]) + if "max" in bc: + kw["body_count_max"] = int(bc["max"]) + else: + if "body_count_min" in d: + kw["body_count_min"] = int(d["body_count_min"]) + if "body_count_max" in d: + kw["body_count_max"] = int(d["body_count_max"]) + + if "complexity_distribution" in d: + cd = d["complexity_distribution"] + if isinstance(cd, dict): + kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()} + if "templates" in d and isinstance(d["templates"], list): + kw["templates"] = [str(t) for t in d["templates"]] + + return cls(**kw) + + +# --------------------------------------------------------------------------- +# Shard specification / result +# --------------------------------------------------------------------------- + + +@dataclass +class ShardSpec: + """Specification for generating a single shard.""" + + shard_id: int + start_example_id: int + count: int + seed: int + complexity_distribution: dict[str, float] + body_count_min: int + body_count_max: int + grounded_ratio: float + + +@dataclass +class ShardResult: + """Result returned from a shard worker.""" + + shard_id: int + num_examples: int + file_path: str + generation_time_s: float + + +# --------------------------------------------------------------------------- +# Seed derivation +# --------------------------------------------------------------------------- + + +def _derive_shard_seed(global_seed: int, shard_id: int) -> int: + """Derive a deterministic per-shard seed from the global seed.""" + h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest() + return int(h[:8], 16) + + +# 
---------------------------------------------------------------------------
# Progress display
# ---------------------------------------------------------------------------


class _PrintProgress:
    """Fallback progress display when tqdm is unavailable."""

    def __init__(self, total: int) -> None:
        self.total = total  # total number of shards expected
        self.current = 0  # shards completed so far
        self.start_time = time.monotonic()

    def update(self, n: int = 1) -> None:
        # Recompute rate/ETA from scratch each call; cheap at shard granularity.
        self.current += n
        elapsed = time.monotonic() - self.start_time
        rate = self.current / elapsed if elapsed > 0 else 0.0
        eta = (self.total - self.current) / rate if rate > 0 else 0.0
        pct = 100.0 * self.current / self.total
        sys.stdout.write(
            f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
            f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
        )
        sys.stdout.flush()

    def close(self) -> None:
        sys.stdout.write("\n")
        sys.stdout.flush()


def _make_progress(total: int) -> _PrintProgress:
    """Create a progress tracker (tqdm if available, else print-based)."""
    try:
        from tqdm import tqdm  # type: ignore[import-untyped]

        return tqdm(total=total, desc="Generating shards", unit="shard")  # type: ignore[no-any-return,return-value]
    except ImportError:
        return _PrintProgress(total)


# ---------------------------------------------------------------------------
# Shard I/O
# ---------------------------------------------------------------------------


def _save_shard(
    shard_id: int,
    examples: list[dict[str, Any]],
    shards_dir: Path,
) -> Path:
    """Save a shard to disk (.pt if torch available, else .json)."""
    shards_dir.mkdir(parents=True, exist_ok=True)
    try:
        import torch  # type: ignore[import-untyped]

        path = shards_dir / f"shard_{shard_id:05d}.pt"
        torch.save(examples, path)
    except ImportError:
        # No torch in this environment: fall back to plain JSON.
        path = shards_dir / f"shard_{shard_id:05d}.json"
        with open(path, "w") as f:
            json.dump(examples, f)
    return path


def _load_shard(path: Path) -> list[dict[str, Any]]:
    """Load a shard from disk (.pt or .json)."""
    if path.suffix == ".pt":
        import torch  # type: ignore[import-untyped]

        # weights_only=False: shards contain arbitrary python dicts, not tensors only.
        result: list[dict[str, Any]] = torch.load(path, weights_only=False)
        return result
    with open(path) as f:
        result = json.load(f)
    return result


def _shard_format() -> str:
    """Return the shard file extension based on available libraries."""
    try:
        import torch  # type: ignore[import-untyped]  # noqa: F401

        return ".pt"
    except ImportError:
        return ".json"


# ---------------------------------------------------------------------------
# Shard worker (module-level for pickling)
# ---------------------------------------------------------------------------


def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
    """Generate a single shard — top-level function for ProcessPoolExecutor."""
    from solver.datagen.generator import SyntheticAssemblyGenerator

    t0 = time.monotonic()
    gen = SyntheticAssemblyGenerator(seed=spec.seed)
    # Separate RNG (+1 offset) for tier sampling so it does not perturb
    # the generator's own random stream.
    rng = np.random.default_rng(spec.seed + 1)

    # Normalize tier probabilities so they sum to 1 for rng.choice.
    tiers = list(spec.complexity_distribution.keys())
    probs_list = [spec.complexity_distribution[t] for t in tiers]
    total = sum(probs_list)
    probs = [p / total for p in probs_list]

    examples: list[dict[str, Any]] = []
    for i in range(spec.count):
        tier_idx = int(rng.choice(len(tiers), p=probs))
        tier = tiers[tier_idx]
        try:
            batch = gen.generate_training_batch(
                batch_size=1,
                complexity_tier=tier,  # type: ignore[arg-type]
                grounded_ratio=spec.grounded_ratio,
            )
            ex = batch[0]
            ex["example_id"] = spec.start_example_id + i
            ex["complexity_tier"] = tier
            examples.append(ex)
        except Exception:
            # Best-effort: a failed example is logged and skipped; the
            # shard is still written with whatever succeeded.
            logger.warning(
                "Shard %d, example %d failed — skipping",
                spec.shard_id,
                i,
                exc_info=True,
            )

    shards_dir = Path(output_dir) / "shards"
    path = _save_shard(spec.shard_id, examples, shards_dir)

    elapsed = time.monotonic() - t0
    return ShardResult(
        shard_id=spec.shard_id,
        num_examples=len(examples),
        file_path=str(path),
generation_time_s=elapsed, + ) + + +# --------------------------------------------------------------------------- +# Minimal YAML parser +# --------------------------------------------------------------------------- + + +def _parse_scalar(value: str) -> int | float | bool | str: + """Parse a YAML scalar value.""" + # Strip inline comments (space + #) + if " #" in value: + value = value[: value.index(" #")].strip() + elif " #" in value: + value = value[: value.index(" #")].strip() + v = value.strip() + if v.lower() in ("true", "yes"): + return True + if v.lower() in ("false", "no"): + return False + try: + return int(v) + except ValueError: + pass + try: + return float(v) + except ValueError: + pass + return v.strip("'\"") + + +def parse_simple_yaml(path: str) -> dict[str, Any]: + """Parse a simple YAML file (flat scalars, one-level dicts, lists). + + This is **not** a full YAML parser. It handles the structure of + ``configs/dataset/synthetic.yaml``. + """ + result: dict[str, Any] = {} + current_key: str | None = None + + with open(path) as f: + for raw_line in f: + line = raw_line.rstrip() + + # Skip blank lines and full-line comments + if not line or line.lstrip().startswith("#"): + continue + + indent = len(line) - len(line.lstrip()) + + if indent == 0 and ":" in line: + key, _, value = line.partition(":") + key = key.strip() + value = value.strip() + if value: + result[key] = _parse_scalar(value) + current_key = None + else: + current_key = key + result[key] = {} + continue + + if indent > 0 and line.lstrip().startswith("- "): + item = line.lstrip()[2:].strip() + if current_key is not None: + if isinstance(result.get(current_key), dict) and not result[current_key]: + result[current_key] = [] + if isinstance(result.get(current_key), list): + result[current_key].append(_parse_scalar(item)) + continue + + if indent > 0 and ":" in line and current_key is not None: + k, _, v = line.partition(":") + k = k.strip() + v = v.strip() + if v: + # Strip inline comments + if 
" #" in v: + v = v[: v.index(" #")].strip() + if not isinstance(result.get(current_key), dict): + result[current_key] = {} + result[current_key][k] = _parse_scalar(v) + continue + + return result + + +# --------------------------------------------------------------------------- +# Dataset generator orchestrator +# --------------------------------------------------------------------------- + + +class DatasetGenerator: + """Orchestrates parallel dataset generation with sharding and checkpointing.""" + + def __init__(self, config: DatasetConfig) -> None: + self.config = config + self.output_path = Path(config.output_dir) + self.shards_dir = self.output_path / "shards" + self.checkpoint_file = self.output_path / ".checkpoint.json" + self.index_file = self.output_path / "index.json" + self.stats_file = self.output_path / "stats.json" + + # -- public API -- + + def run(self) -> None: + """Generate the full dataset.""" + self.output_path.mkdir(parents=True, exist_ok=True) + self.shards_dir.mkdir(parents=True, exist_ok=True) + + shards = self._plan_shards() + total_shards = len(shards) + + # Resume: find already-completed shards + completed: set[int] = set() + if self.config.resume: + completed = self._find_completed_shards() + + pending = [s for s in shards if s.shard_id not in completed] + + if not pending: + logger.info("All %d shards already complete.", total_shards) + else: + logger.info( + "Generating %d shards (%d already complete).", + len(pending), + len(completed), + ) + progress = _make_progress(len(pending)) + workers = max(1, self.config.num_workers) + checkpoint_counter = 0 + + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit( + _generate_shard_worker, + spec, + str(self.output_path), + ): spec.shard_id + for spec in pending + } + for future in as_completed(futures): + shard_id = futures[future] + try: + result = future.result() + completed.add(result.shard_id) + logger.debug( + "Shard %d: %d examples in %.1fs", + 
result.shard_id,
                            result.num_examples,
                            result.generation_time_s,
                        )
                    except Exception:
                        # A failed shard is logged but does not abort the run;
                        # it will be retried on the next resume.
                        logger.error("Shard %d failed", shard_id, exc_info=True)
                    progress.update(1)
                    checkpoint_counter += 1
                    if checkpoint_counter >= self.config.checkpoint_every:
                        self._update_checkpoint(completed, total_shards)
                        checkpoint_counter = 0

            progress.close()

        # Finalize
        self._build_index()
        stats = self._compute_statistics()
        self._write_statistics(stats)
        self._print_summary(stats)

        # Remove checkpoint (generation complete)
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    # -- internal helpers --

    def _plan_shards(self) -> list[ShardSpec]:
        """Divide num_assemblies into shards (last shard may be short)."""
        n = self.config.num_assemblies
        size = self.config.shard_size
        num_shards = math.ceil(n / size)
        shards: list[ShardSpec] = []
        for i in range(num_shards):
            start = i * size
            count = min(size, n - start)
            shards.append(
                ShardSpec(
                    shard_id=i,
                    start_example_id=start,
                    count=count,
                    # Per-shard seed is derived, not sequential, so shards are
                    # reproducible independently of worker scheduling.
                    seed=_derive_shard_seed(self.config.seed, i),
                    complexity_distribution=dict(self.config.complexity_distribution),
                    body_count_min=self.config.body_count_min,
                    body_count_max=self.config.body_count_max,
                    grounded_ratio=self.config.grounded_ratio,
                )
            )
        return shards

    def _find_completed_shards(self) -> set[int]:
        """Scan shards directory for existing shard files.

        A shard counts as complete when a non-empty ``shard_<id>`` file
        exists, regardless of extension (.pt or .json).
        """
        completed: set[int] = set()
        if not self.shards_dir.exists():
            return completed

        for p in self.shards_dir.iterdir():
            if p.stem.startswith("shard_"):
                try:
                    shard_id = int(p.stem.split("_")[1])
                    # Verify file is non-empty
                    if p.stat().st_size > 0:
                        completed.add(shard_id)
                except (ValueError, IndexError):
                    pass
        return completed

    def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
        """Write checkpoint file (JSON, UTC timestamp)."""
        data = {
            "completed_shards": sorted(completed),
            "total_shards": total_shards,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f)

    def _build_index(self) -> None:
        """Build index.json mapping shard files to assembly ID ranges."""
        shards_info: dict[str, dict[str, int]] = {}
        total_assemblies = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            try:
                shard_id = int(p.stem.split("_")[1])
            except (ValueError, IndexError):
                continue
            examples = _load_shard(p)
            count = len(examples)
            # start_id is positional (shard_id * shard_size), mirroring
            # ShardSpec.start_example_id used at generation time.
            start_id = shard_id * self.config.shard_size
            shards_info[p.name] = {"start_id": start_id, "count": count}
            total_assemblies += count

        fmt = _shard_format().lstrip(".")
        index = {
            "format_version": 1,
            "total_assemblies": total_assemblies,
            "total_shards": len(shards_info),
            "shard_format": fmt,
            "shards": shards_info,
        }
        with open(self.index_file, "w") as f:
            json.dump(index, f, indent=2)

    def _compute_statistics(self) -> dict[str, Any]:
        """Aggregate statistics across all shards (reads every shard file)."""
        classification_counts: dict[str, int] = {}
        body_count_hist: dict[int, int] = {}
        joint_type_counts: dict[str, int] = {}
        dof_values: list[int] = []
        degeneracy_values: list[int] = []
        rigid_count = 0
        minimally_rigid_count = 0
        total = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            examples = _load_shard(p)
            for ex in examples:
                total += 1
                cls = str(ex.get("assembly_classification", "unknown"))
                classification_counts[cls] = classification_counts.get(cls, 0) + 1
                nb = int(ex.get("n_bodies", 0))
                body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
                for j in ex.get("joints", []):
                    jt = str(j.get("type", "unknown"))
                    joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
                dof_values.append(int(ex.get("internal_dof", 0)))
                degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
                if ex.get("is_rigid"):
                    rigid_count += 1
                if ex.get("is_minimally_rigid"):
                    minimally_rigid_count += 1

        # Placeholder single-zero arrays keep the stats block computable
        # when no examples were found (mean/min/max of an empty array raise).
        dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
        deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)

        return {
            "total_examples": total,
            "classification_distribution": dict(sorted(classification_counts.items())),
            "body_count_histogram": dict(sorted(body_count_hist.items())),
            "joint_type_distribution": dict(sorted(joint_type_counts.items())),
            "dof_statistics": {
                "mean": float(dof_arr.mean()),
                "std": float(dof_arr.std()),
                "min": int(dof_arr.min()),
                "max": int(dof_arr.max()),
                "median": float(np.median(dof_arr)),
            },
            "geometric_degeneracy": {
                "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
                "fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
                "mean_degeneracies": float(deg_arr.mean()),
            },
            "rigidity": {
                "rigid_count": rigid_count,
                "rigid_fraction": (rigid_count / total if total > 0 else 0.0),
                "minimally_rigid_count": minimally_rigid_count,
                "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
            },
        }

    def _write_statistics(self, stats: dict[str, Any]) -> None:
        """Write stats.json."""
        with open(self.stats_file, "w") as f:
            json.dump(stats, f, indent=2)

    def _print_summary(self, stats: dict[str, Any]) -> None:
        """Print a human-readable summary to stdout."""
        print("\n=== Dataset Generation Summary ===")
        print(f"Total examples: {stats['total_examples']}")
        print(f"Output directory: {self.output_path}")
        print()
        print("Classification distribution:")
        for cls, count in stats["classification_distribution"].items():
            # max(..., 1) guards the zero-example case.
            frac = count / max(stats["total_examples"], 1) * 100
            print(f"  {cls}: {count} ({frac:.1f}%)")
        print()
        print("Joint type distribution:")
        for jt, count in stats["joint_type_distribution"].items():
            print(f"  {jt}: {count}")
        print()
        dof = stats["dof_statistics"]
        print(
            f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
        )
        rig = stats["rigidity"]
        print(
            f"Rigidity: 
{rig['rigid_count']}/{stats['total_examples']} " + f"({rig['rigid_fraction'] * 100:.1f}%) rigid, " + f"{rig['minimally_rigid_count']} minimally rigid" + ) + deg = stats["geometric_degeneracy"] + print( + f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies " + f"({deg['fraction_with_degeneracy'] * 100:.1f}%)" + ) diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py new file mode 100644 index 0000000..ad753f9 --- /dev/null +++ b/solver/datagen/generator.py @@ -0,0 +1,893 @@ +"""Synthetic assembly graph generator for training data production. + +Generates assembly graphs with known constraint classifications using +the pebble game and Jacobian verification. Each assembly is fully labeled +with per-constraint independence flags and assembly-level classification. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Literal + +import numpy as np +from scipy.spatial.transform import Rotation + +from solver.datagen.analysis import analyze_assembly +from solver.datagen.labeling import label_assembly +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "COMPLEXITY_RANGES", + "AxisStrategy", + "ComplexityTier", + "SyntheticAssemblyGenerator", +] + +# --------------------------------------------------------------------------- +# Complexity tiers — ranges use exclusive upper bound for rng.integers() +# --------------------------------------------------------------------------- + +ComplexityTier = Literal["simple", "medium", "complex"] + +COMPLEXITY_RANGES: dict[str, tuple[int, int]] = { + "simple": (2, 6), + "medium": (6, 16), + "complex": (16, 51), +} + +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + +AxisStrategy = Literal["cardinal", "random", "near_parallel"] + + +class 
SyntheticAssemblyGenerator:
    """Generates assembly graphs with known minimal constraint sets.

    Uses the pebble game to incrementally build assemblies, tracking
    exactly which constraints are independent at each step. This produces
    labeled training data: (assembly_graph, constraint_set, labels).

    Labels per constraint:
    - independent: bool (does this constraint remove a DOF?)
    - redundant: bool (is this constraint overconstrained?)
    - minimal_set: bool (part of a minimal rigidity basis?)
    """

    def __init__(self, seed: int = 42) -> None:
        # Single Generator instance: all randomness flows through self.rng
        # so a given seed reproduces the same assemblies.
        self.rng = np.random.default_rng(seed)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _random_position(self, scale: float = 5.0) -> np.ndarray:
        """Generate random 3D position within [-scale, scale] cube."""
        return self.rng.uniform(-scale, scale, size=3)

    def _random_axis(self) -> np.ndarray:
        """Generate random normalized 3D axis (uniform on the sphere)."""
        axis = self.rng.standard_normal(3)
        axis /= np.linalg.norm(axis)
        return axis

    def _random_orientation(self) -> np.ndarray:
        """Generate a random 3x3 rotation matrix."""
        mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
        return mat

    def _cardinal_axis(self) -> np.ndarray:
        """Pick uniformly from the six signed cardinal directions."""
        axes = np.array(
            [
                [1, 0, 0],
                [-1, 0, 0],
                [0, 1, 0],
                [0, -1, 0],
                [0, 0, 1],
                [0, 0, -1],
            ],
            dtype=float,
        )
        result: np.ndarray = axes[int(self.rng.integers(6))]
        return result

    def _near_parallel_axis(
        self,
        base_axis: np.ndarray,
        perturbation_scale: float = 0.05,
    ) -> np.ndarray:
        """Return *base_axis* with a small random perturbation, re-normalized."""
        perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
        return perturbed / np.linalg.norm(perturbed)

    def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
        """Sample a joint axis using the specified strategy.

        Unknown strategy values fall through to the "random" behavior.
        """
        if strategy == "cardinal":
            return self._cardinal_axis()
        if strategy == "near_parallel":
            # near_parallel without a shared base perturbs the +Z axis.
            return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
        return self._random_axis()

    def _resolve_axis(
        self,
        strategy: AxisStrategy,
        parallel_axis_prob: float,
        shared_axis: np.ndarray | None,
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """Return (axis_for_this_joint, shared_axis_to_propagate).

        On the first call where *shared_axis* is ``None`` and parallel
        injection triggers, a base axis is chosen and returned as
        *shared_axis* for subsequent calls.
        """
        if shared_axis is not None:
            # Once a shared axis exists, all later joints stay near-parallel to it.
            return self._near_parallel_axis(shared_axis), shared_axis
        if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
            base = self._sample_axis(strategy)
            return base.copy(), base
        return self._sample_axis(strategy), None

    def _select_joint_type(
        self,
        joint_types: JointType | list[JointType],
    ) -> JointType:
        """Select a joint type from a single type or list (uniform pick)."""
        if isinstance(joint_types, list):
            idx = int(self.rng.integers(0, len(joint_types)))
            return joint_types[idx]
        return joint_types

    def _create_joint(
        self,
        joint_id: int,
        body_a_id: int,
        body_b_id: int,
        pos_a: np.ndarray,
        pos_b: np.ndarray,
        joint_type: JointType,
        *,
        axis: np.ndarray | None = None,
        orient_a: np.ndarray | None = None,
        orient_b: np.ndarray | None = None,
    ) -> Joint:
        """Create a joint between two bodies.

        When orientations are provided, anchor points are offset from
        each body's center along a random local direction rotated into
        world frame, rather than placed at the midpoint.
+ """ + if orient_a is not None and orient_b is not None: + dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1) + offset_scale = dist * 0.2 + local_a = self.rng.standard_normal(3) * offset_scale + local_b = self.rng.standard_normal(3) * offset_scale + anchor_a = pos_a + orient_a @ local_a + anchor_b = pos_b + orient_b @ local_b + else: + anchor = (pos_a + pos_b) / 2.0 + anchor_a = anchor + anchor_b = anchor + + return Joint( + joint_id=joint_id, + body_a=body_a_id, + body_b=body_b_id, + joint_type=joint_type, + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis if axis is not None else self._random_axis(), + ) + + # ------------------------------------------------------------------ + # Original generators (chain / rigid / overconstrained) + # ------------------------------------------------------------------ + + def generate_chain_assembly( + self, + n_bodies: int, + joint_type: JointType = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a serial kinematic chain. + + Simple but useful: each body connects to the next with the + specified joint type. Results in an underconstrained assembly + (serial chain is never rigid without closing loops). 
+ """ + bodies = [] + joints = [] + + for i in range(n_bodies): + pos = np.array([i * 2.0, 0.0, 0.0]) + bodies.append( + RigidBody( + body_id=i, + position=pos, + orientation=self._random_orientation(), + ) + ) + + shared_axis: np.ndarray | None = None + for i in range(n_bodies - 1): + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i, + i, + i + 1, + bodies[i].position, + bodies[i + 1].position, + joint_type, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[i + 1].orientation, + ) + ) + + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) + return bodies, joints, analysis + + def generate_rigid_assembly( + self, + n_bodies: int, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a minimally rigid assembly by adding joints until rigid. + + Strategy: start with fixed joints on a spanning tree (guarantees + rigidity), then randomly relax some to weaker joint types while + maintaining rigidity via the pebble game check. 
+ """ + bodies = [] + for i in range(n_bodies): + bodies.append( + RigidBody( + body_id=i, + position=self._random_position(), + orientation=self._random_orientation(), + ) + ) + + # Build spanning tree with fixed joints (overconstrained) + joints: list[Joint] = [] + shared_axis: np.ndarray | None = None + for i in range(1, n_bodies): + parent = int(self.rng.integers(0, i)) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i - 1, + parent, + i, + bodies[parent].position, + bodies[i].position, + JointType.FIXED, + axis=axis, + orient_a=bodies[parent].orientation, + orient_b=bodies[i].orientation, + ) + ) + + # Try relaxing joints to weaker types while maintaining rigidity + weaker_types = [ + JointType.REVOLUTE, + JointType.CYLINDRICAL, + JointType.BALL, + ] + + ground = 0 if grounded else None + for idx in self.rng.permutation(len(joints)): + original_type = joints[idx].joint_type + for candidate in weaker_types: + joints[idx].joint_type = candidate + analysis = analyze_assembly(bodies, joints, ground_body=ground) + if analysis.is_rigid: + break # Keep the weaker type + else: + joints[idx].joint_type = original_type + + analysis = analyze_assembly(bodies, joints, ground_body=ground) + return bodies, joints, analysis + + def generate_overconstrained_assembly( + self, + n_bodies: int, + extra_joints: int = 2, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate an assembly with known redundant constraints. + + Starts with a rigid assembly, then adds extra joints that + the pebble game will flag as redundant. 
+ """ + bodies, joints, _ = self.generate_rigid_assembly( + n_bodies, + grounded=grounded, + axis_strategy=axis_strategy, + parallel_axis_prob=parallel_axis_prob, + ) + + joint_id = len(joints) + shared_axis: np.ndarray | None = None + for _ in range(extra_joints): + a, b = self.rng.choice(n_bodies, size=2, replace=False) + _overcon_types = [ + JointType.REVOLUTE, + JointType.FIXED, + JointType.BALL, + ] + jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))] + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + joint_id, + int(a), + int(b), + bodies[int(a)].position, + bodies[int(b)].position, + jtype, + axis=axis, + orient_a=bodies[int(a)].orientation, + orient_b=bodies[int(b)].orientation, + ) + ) + joint_id += 1 + + ground = 0 if grounded else None + analysis = analyze_assembly(bodies, joints, ground_body=ground) + return bodies, joints, analysis + + # ------------------------------------------------------------------ + # New topology generators + # ------------------------------------------------------------------ + + def generate_tree_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + branching_factor: int = 3, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a random tree topology with configurable branching. + + Creates a tree where each body can have up to *branching_factor* + children. Different branches can use different joint types if a + list is provided. Always underconstrained (no closed loops). + + Args: + n_bodies: Total bodies (root + children). + joint_types: Single type or list to sample from per joint. + branching_factor: Max children per parent (1-5 recommended). 
+ """ + bodies: list[RigidBody] = [ + RigidBody( + body_id=0, + position=np.zeros(3), + orientation=self._random_orientation(), + ) + ] + joints: list[Joint] = [] + + available_parents = [0] + next_id = 1 + joint_id = 0 + shared_axis: np.ndarray | None = None + + while next_id < n_bodies and available_parents: + pidx = int(self.rng.integers(0, len(available_parents))) + parent_id = available_parents[pidx] + parent_pos = bodies[parent_id].position + + max_children = min(branching_factor, n_bodies - next_id) + n_children = int(self.rng.integers(1, max_children + 1)) + + for _ in range(n_children): + direction = self._random_axis() + distance = self.rng.uniform(1.5, 3.0) + child_pos = parent_pos + direction * distance + child_orient = self._random_orientation() + + bodies.append( + RigidBody( + body_id=next_id, + position=child_pos, + orientation=child_orient, + ) + ) + jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + joint_id, + parent_id, + next_id, + parent_pos, + child_pos, + jtype, + axis=axis, + orient_a=bodies[parent_id].orientation, + orient_b=child_orient, + ) + ) + + available_parents.append(next_id) + next_id += 1 + joint_id += 1 + if next_id >= n_bodies: + break + + # Retire parent if it reached branching limit or randomly + if n_children >= branching_factor or self.rng.random() < 0.3: + available_parents.pop(pidx) + + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) + return bodies, joints, analysis + + def generate_loop_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a single closed loop (ring) of bodies. 
+ + The closing constraint introduces redundancy, making this + useful for generating overconstrained training data. + + Args: + n_bodies: Bodies in the ring (>= 3). + joint_types: Single type or list to sample from per joint. + + Raises: + ValueError: If *n_bodies* < 3. + """ + if n_bodies < 3: + msg = "Loop assembly requires at least 3 bodies" + raise ValueError(msg) + + bodies: list[RigidBody] = [] + joints: list[Joint] = [] + + base_radius = max(2.0, n_bodies * 0.4) + for i in range(n_bodies): + angle = 2 * np.pi * i / n_bodies + radius = base_radius + self.rng.uniform(-0.5, 0.5) + x = radius * np.cos(angle) + y = radius * np.sin(angle) + z = float(self.rng.uniform(-0.2, 0.2)) + bodies.append( + RigidBody( + body_id=i, + position=np.array([x, y, z]), + orientation=self._random_orientation(), + ) + ) + + shared_axis: np.ndarray | None = None + for i in range(n_bodies): + next_i = (i + 1) % n_bodies + jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i, + i, + next_i, + bodies[i].position, + bodies[next_i].position, + jtype, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[next_i].orientation, + ) + ) + + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) + return bodies, joints, analysis + + def generate_star_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a star topology with central hub and satellites. + + Body 0 is the hub; all other bodies connect directly to it. + Underconstrained because there are no inter-satellite connections. + + Args: + n_bodies: Total bodies including hub (>= 2). 
            joint_types: Single type or list to sample from per joint.

        Raises:
            ValueError: If *n_bodies* < 2.
        """
        if n_bodies < 2:
            msg = "Star assembly requires at least 2 bodies"
            raise ValueError(msg)

        hub_orient = self._random_orientation()
        bodies: list[RigidBody] = [
            RigidBody(
                body_id=0,
                position=np.zeros(3),
                orientation=hub_orient,
            )
        ]
        joints: list[Joint] = []

        shared_axis: np.ndarray | None = None
        for i in range(1, n_bodies):
            # Satellite placed at a random direction/distance from the hub.
            direction = self._random_axis()
            distance = self.rng.uniform(2.0, 5.0)
            pos = direction * distance
            sat_orient = self._random_orientation()
            bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))

            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    i - 1,
                    0,
                    i,
                    np.zeros(3),
                    pos,
                    jtype,
                    axis=axis,
                    orient_a=hub_orient,
                    orient_b=sat_orient,
                )
            )

        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    def generate_mixed_assembly(
        self,
        n_bodies: int,
        joint_types: JointType | list[JointType] = JointType.REVOLUTE,
        edge_density: float = 0.3,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a mixed topology combining tree and loop elements.

        Builds a spanning tree for connectivity, then adds extra edges
        based on *edge_density* to create loops and redundancy.

        Args:
            n_bodies: Number of bodies.
            joint_types: Single type or list to sample from per joint.
            edge_density: Fraction of non-tree edges to add (0.0-1.0).

        Raises:
            ValueError: If *edge_density* not in [0.0, 1.0].
        """
        if not 0.0 <= edge_density <= 1.0:
            msg = "edge_density must be in [0.0, 1.0]"
            raise ValueError(msg)

        bodies: list[RigidBody] = []
        joints: list[Joint] = []

        for i in range(n_bodies):
            bodies.append(
                RigidBody(
                    body_id=i,
                    position=self._random_position(),
                    orientation=self._random_orientation(),
                )
            )

        # Phase 1: spanning tree
        joint_id = 0
        existing_edges: set[frozenset[int]] = set()
        shared_axis: np.ndarray | None = None
        for i in range(1, n_bodies):
            parent = int(self.rng.integers(0, i))
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    parent,
                    i,
                    bodies[parent].position,
                    bodies[i].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[parent].orientation,
                    orient_b=bodies[i].orientation,
                )
            )
            # frozenset makes the edge undirected for duplicate detection.
            existing_edges.add(frozenset([parent, i]))
            joint_id += 1

        # Phase 2: extra edges based on density
        candidates: list[tuple[int, int]] = []
        for i in range(n_bodies):
            for j in range(i + 1, n_bodies):
                if frozenset([i, j]) not in existing_edges:
                    candidates.append((i, j))

        n_extra = int(edge_density * len(candidates))
        self.rng.shuffle(candidates)

        for a, b in candidates[:n_extra]:
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    a,
                    b,
                    bodies[a].position,
                    bodies[b].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[a].orientation,
                    orient_b=bodies[b].orientation,
                )
            )
            joint_id += 1

        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    # ------------------------------------------------------------------
    # Batch generation
    # ------------------------------------------------------------------

    def generate_training_batch(
self, + batch_size: int = 100, + n_bodies_range: tuple[int, int] | None = None, + complexity_tier: ComplexityTier | None = None, + *, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + grounded_ratio: float = 1.0, + ) -> list[dict[str, Any]]: + """Generate a batch of labeled training examples. + + Each example contains body positions, joint descriptions, + per-joint independence labels, and assembly-level classification. + + Args: + batch_size: Number of assemblies to generate. + n_bodies_range: ``(min, max_exclusive)`` body count. + Overridden by *complexity_tier* when both are given. + complexity_tier: Predefined range (``"simple"`` / ``"medium"`` + / ``"complex"``). Overrides *n_bodies_range*. + axis_strategy: Axis sampling strategy for joint axes. + parallel_axis_prob: Probability of parallel axis injection. + grounded_ratio: Fraction of examples that are grounded. + """ + if complexity_tier is not None: + n_bodies_range = COMPLEXITY_RANGES[complexity_tier] + elif n_bodies_range is None: + n_bodies_range = (3, 8) + + _joint_pool = [ + JointType.REVOLUTE, + JointType.BALL, + JointType.CYLINDRICAL, + JointType.FIXED, + ] + + geo_kw: dict[str, Any] = { + "axis_strategy": axis_strategy, + "parallel_axis_prob": parallel_axis_prob, + } + + examples: list[dict[str, Any]] = [] + + for i in range(batch_size): + n = int(self.rng.integers(*n_bodies_range)) + gen_idx = int(self.rng.integers(7)) + grounded = bool(self.rng.random() < grounded_ratio) + + if gen_idx == 0: + _chain_types = [ + JointType.REVOLUTE, + JointType.BALL, + JointType.CYLINDRICAL, + ] + jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] + bodies, joints, analysis = self.generate_chain_assembly( + n, + jtype, + grounded=grounded, + **geo_kw, + ) + gen_name = "chain" + elif gen_idx == 1: + bodies, joints, analysis = self.generate_rigid_assembly( + n, + grounded=grounded, + **geo_kw, + ) + gen_name = "rigid" + elif gen_idx == 2: + extra = int(self.rng.integers(1, 
4)) + bodies, joints, analysis = self.generate_overconstrained_assembly( + n, + extra, + grounded=grounded, + **geo_kw, + ) + gen_name = "overconstrained" + elif gen_idx == 3: + branching = int(self.rng.integers(2, 5)) + bodies, joints, analysis = self.generate_tree_assembly( + n, + _joint_pool, + branching, + grounded=grounded, + **geo_kw, + ) + gen_name = "tree" + elif gen_idx == 4: + n = max(n, 3) + bodies, joints, analysis = self.generate_loop_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) + gen_name = "loop" + elif gen_idx == 5: + n = max(n, 2) + bodies, joints, analysis = self.generate_star_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) + gen_name = "star" + else: + density = float(self.rng.uniform(0.2, 0.5)) + bodies, joints, analysis = self.generate_mixed_assembly( + n, + _joint_pool, + density, + grounded=grounded, + **geo_kw, + ) + gen_name = "mixed" + + # Produce ground truth labels (includes ConstraintAnalysis) + ground = 0 if grounded else None + labels = label_assembly(bodies, joints, ground_body=ground) + analysis = labels.analysis + + # Build per-joint labels from edge results + joint_labels: dict[int, dict[str, int]] = {} + for result in analysis.per_edge_results: + jid = result["joint_id"] + if jid not in joint_labels: + joint_labels[jid] = { + "independent_constraints": 0, + "redundant_constraints": 0, + "total_constraints": 0, + } + joint_labels[jid]["total_constraints"] += 1 + if result["independent"]: + joint_labels[jid]["independent_constraints"] += 1 + else: + joint_labels[jid]["redundant_constraints"] += 1 + + examples.append( + { + "example_id": i, + "generator_type": gen_name, + "grounded": grounded, + "n_bodies": len(bodies), + "n_joints": len(joints), + "body_positions": [b.position.tolist() for b in bodies], + "body_orientations": [b.orientation.tolist() for b in bodies], + "joints": [ + { + "joint_id": j.joint_id, + "body_a": j.body_a, + "body_b": j.body_b, + "type": j.joint_type.name, + "axis": 
j.axis.tolist(), + } + for j in joints + ], + "joint_labels": joint_labels, + "labels": labels.to_dict(), + "assembly_classification": (analysis.combinatorial_classification), + "is_rigid": analysis.is_rigid, + "is_minimally_rigid": analysis.is_minimally_rigid, + "internal_dof": analysis.jacobian_internal_dof, + "geometric_degeneracies": (analysis.geometric_degeneracies), + } + ) + + return examples diff --git a/solver/datagen/jacobian.py b/solver/datagen/jacobian.py new file mode 100644 index 0000000..09bb769 --- /dev/null +++ b/solver/datagen/jacobian.py @@ -0,0 +1,517 @@ +"""Numerical Jacobian rank verification for assembly constraint analysis. + +Builds the constraint Jacobian matrix and analyzes its numerical rank +to detect geometric degeneracies that the combinatorial pebble game +cannot identify (e.g., parallel revolute axes creating hidden dependencies). + +References: + - Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D" +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.types import Joint, JointType, RigidBody + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["JacobianVerifier"] + + +class JacobianVerifier: + """Builds and analyzes the constraint Jacobian for numerical rank check. + + The pebble game gives a combinatorial *necessary* condition for + rigidity. However, geometric special cases (e.g., all revolute axes + parallel, creating a hidden dependency) require numerical verification. + + For each joint, we construct the constraint Jacobian rows that map + the 6n-dimensional generalized velocity vector to the constraint + violation rates. The rank of this Jacobian equals the number of + truly independent constraints. 
+ + The generalized velocity vector for n bodies is:: + + v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z] + + Each scalar constraint C_i contributes one row to J such that:: + + dC_i/dt = J_i @ v = 0 + """ + + def __init__(self, bodies: list[RigidBody]) -> None: + self.bodies = {b.body_id: b for b in bodies} + self.body_index = {b.body_id: i for i, b in enumerate(bodies)} + self.n_bodies = len(bodies) + self.jacobian_rows: list[np.ndarray] = [] + self.row_labels: list[dict[str, Any]] = [] + + def _body_cols(self, body_id: int) -> tuple[int, int]: + """Return the column range [start, end) for a body in J.""" + idx = self.body_index[body_id] + return idx * 6, (idx + 1) * 6 + + def add_joint_constraints(self, joint: Joint) -> int: + """Add Jacobian rows for all scalar constraints of a joint. + + Returns the number of rows added. + """ + builder = { + JointType.FIXED: self._build_fixed, + JointType.REVOLUTE: self._build_revolute, + JointType.CYLINDRICAL: self._build_cylindrical, + JointType.SLIDER: self._build_slider, + JointType.BALL: self._build_ball, + JointType.PLANAR: self._build_planar, + JointType.DISTANCE: self._build_distance, + JointType.PARALLEL: self._build_parallel, + JointType.PERPENDICULAR: self._build_perpendicular, + JointType.UNIVERSAL: self._build_universal, + JointType.SCREW: self._build_screw, + } + + rows_before = len(self.jacobian_rows) + builder[joint.joint_type](joint) + return len(self.jacobian_rows) - rows_before + + def _make_row(self) -> np.ndarray: + """Create a zero row of width 6*n_bodies.""" + return np.zeros(6 * self.n_bodies) + + def _skew(self, v: np.ndarray) -> np.ndarray: + """Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``.""" + return np.array( + [ + [0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0], + ] + ) + + # --- Ball-and-socket (spherical) joint: 3 translation constraints --- + + def _build_ball(self, joint: Joint) -> None: + """Ball joint: coincident point constraint. 
+ + ``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0`` + (3 equations) + + Jacobian rows (for each of x, y, z): + body_a linear: -I + body_a angular: +skew(R_a @ r_a) + body_b linear: +I + body_b angular: -skew(R_b @ r_b) + """ + # Use anchor positions directly as world-frame offsets + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + e = np.zeros(3) + e[axis_idx] = 1.0 + + row[col_a_start : col_a_start + 3] = -e + row[col_a_start + 3 : col_a_end] = np.cross(r_a, e) + + row[col_b_start : col_b_start + 3] = e + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "ball_translation", + "axis": axis_idx, + } + ) + + # --- Fixed joint: 3 translation + 3 rotation constraints --- + + def _build_fixed(self, joint: Joint) -> None: + """Fixed joint = ball joint + locked rotation. + + Translation part: same as ball joint (3 rows). + Rotation part: relative angular velocity must be zero (3 rows). + """ + self._build_ball(joint) + + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "fixed_rotation", + "axis": axis_idx, + } + ) + + # --- Revolute (hinge) joint: 3 translation + 2 rotation constraints --- + + def _build_revolute(self, joint: Joint) -> None: + """Revolute joint: rotation only about one axis. + + Translation: same as ball (3 rows). + Rotation: relative angular velocity must be parallel to hinge axis. 
+ """ + self._build_ball(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "revolute_rotation", + "perp_axis": i, + } + ) + + # --- Cylindrical joint: 2 translation + 2 rotation constraints --- + + def _build_cylindrical(self, joint: Joint) -> None: + """Cylindrical joint: allows rotation + translation along one axis. + + Translation: constrain motion perpendicular to axis (2 rows). + Rotation: constrain rotation perpendicular to axis (2 rows). + """ + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start : col_a_start + 3] = -t + row[col_a_start + 3 : col_a_end] = np.cross(r_a, t) + row[col_b_start : col_b_start + 3] = t + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "cylindrical_translation", + "perp_axis": i, + } + ) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "cylindrical_rotation", + "perp_axis": i, + } + ) + + # --- Slider (prismatic) joint: 2 translation + 3 rotation constraints --- + + def 
_build_slider(self, joint: Joint) -> None: + """Slider/prismatic joint: translation along one axis only. + + Translation: perpendicular translation constrained (2 rows). + Rotation: all relative rotation constrained (3 rows). + """ + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start : col_a_start + 3] = -t + row[col_a_start + 3 : col_a_end] = np.cross(r_a, t) + row[col_b_start : col_b_start + 3] = t + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "slider_translation", + "perp_axis": i, + } + ) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "slider_rotation", + "axis": axis_idx, + } + ) + + # --- Planar joint: 1 translation + 2 rotation constraints --- + + def _build_planar(self, joint: Joint) -> None: + """Planar joint: constrains to a plane. + + Translation: motion along plane normal constrained (1 row). + Rotation: rotation about axes in the plane constrained (2 rows). 
+ """ + normal = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(normal) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -normal + row[col_a_start + 3 : col_a_end] = np.cross(r_a, normal) + row[col_b_start : col_b_start + 3] = normal + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, normal) + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "planar_translation", + } + ) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "planar_rotation", + "perp_axis": i, + } + ) + + # --- Distance constraint: 1 scalar --- + + def _build_distance(self, joint: Joint) -> None: + """Distance constraint: ``||p_b - p_a|| = d``. + + Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0`` + where ``direction = normalized(p_b - p_a)``. 
+ """ + p_a = joint.anchor_a + p_b = joint.anchor_b + diff = p_b - p_a + dist = np.linalg.norm(diff) + direction = np.array([1.0, 0.0, 0.0]) if dist < 1e-12 else diff / dist + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -direction + row[col_a_start + 3 : col_a_end] = np.cross(r_a, direction) + row[col_b_start : col_b_start + 3] = direction + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, direction) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "distance", + } + ) + + # --- Parallel constraint: 3 rotation constraints --- + + def _build_parallel(self, joint: Joint) -> None: + """Parallel: all relative rotation constrained (same as fixed rotation). + + In practice only 2 of 3 are independent for a single axis, but + we emit 3 and let the rank check sort it out. 
+ """ + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "parallel_rotation", + "axis": axis_idx, + } + ) + + # --- Perpendicular constraint: 1 angular --- + + def _build_perpendicular(self, joint: Joint) -> None: + """Perpendicular: single dot-product angular constraint.""" + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -axis + row[col_b_start + 3 : col_b_start + 6] = axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "perpendicular", + } + ) + + # --- Universal (Cardan) joint: 3 translation + 1 rotation --- + + def _build_universal(self, joint: Joint) -> None: + """Universal joint: ball + one rotation constraint. + + Allows rotation about two axes, constrains rotation about the third. + """ + self._build_ball(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -axis + row[col_b_start + 3 : col_b_start + 6] = axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "universal_rotation", + } + ) + + # --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled --- + + def _build_screw(self, joint: Joint) -> None: + """Screw joint: coupled rotation-translation along axis. 
+ + Like cylindrical but with a coupling constraint: + ``v_axial - pitch * w_axial = 0`` + """ + self._build_cylindrical(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -axis + row[col_b_start : col_b_start + 3] = axis + row[col_a_start + 3 : col_a_start + 6] = joint.pitch * axis + row[col_b_start + 3 : col_b_start + 6] = -joint.pitch * axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "screw_coupling", + } + ) + + # --- Utilities --- + + def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Generate two unit vectors perpendicular to *axis* and each other.""" + if abs(axis[0]) < 0.9: + t1 = np.cross(axis, np.array([1.0, 0, 0])) + else: + t1 = np.cross(axis, np.array([0, 1.0, 0])) + t1 /= np.linalg.norm(t1) + t2 = np.cross(axis, t1) + t2 /= np.linalg.norm(t2) + return t1, t2 + + def get_jacobian(self) -> np.ndarray: + """Return the full constraint Jacobian matrix.""" + if not self.jacobian_rows: + return np.zeros((0, 6 * self.n_bodies)) + return np.array(self.jacobian_rows) + + def numerical_rank(self, tol: float = 1e-8) -> int: + """Compute the numerical rank of the constraint Jacobian via SVD. + + This is the number of truly independent scalar constraints, + accounting for geometric degeneracies that the combinatorial + pebble game cannot detect. + """ + j = self.get_jacobian() + if j.size == 0: + return 0 + sv = np.linalg.svd(j, compute_uv=False) + return int(np.sum(sv > tol)) + + def find_dependencies(self, tol: float = 1e-8) -> list[int]: + """Identify which constraint rows are numerically dependent. + + Returns indices of rows that can be removed without changing + the Jacobian's rank. 
+ """ + j = self.get_jacobian() + if j.size == 0: + return [] + + n_rows = j.shape[0] + dependent: list[int] = [] + + current = np.zeros((0, j.shape[1])) + current_rank = 0 + + for i in range(n_rows): + candidate = np.vstack([current, j[i : i + 1, :]]) if current.size else j[i : i + 1, :] + sv = np.linalg.svd(candidate, compute_uv=False) + new_rank = int(np.sum(sv > tol)) + + if new_rank > current_rank: + current = candidate + current_rank = new_rank + else: + dependent.append(i) + + return dependent diff --git a/solver/datagen/labeling.py b/solver/datagen/labeling.py new file mode 100644 index 0000000..e03de42 --- /dev/null +++ b/solver/datagen/labeling.py @@ -0,0 +1,394 @@ +"""Ground truth labeling pipeline for synthetic assemblies. + +Produces rich per-constraint, per-joint, per-body, and assembly-level +labels by running both the pebble game and Jacobian verification and +correlating their results. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["AssemblyLabels", "label_assembly"] + +_GROUND_ID = -1 +_SVD_TOL = 1e-8 + + +# --------------------------------------------------------------------------- +# Label dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ConstraintLabel: + """Per scalar-constraint label combining both analysis methods.""" + + joint_id: int + constraint_idx: int + pebble_independent: bool + jacobian_independent: bool + + +@dataclass +class JointLabel: + """Aggregated constraint counts for a single joint.""" + + joint_id: int + independent_count: int + redundant_count: int + total: int + + +@dataclass +class BodyDofLabel: + """Per-body 
DOF signature from nullspace projection.""" + + body_id: int + translational_dof: int + rotational_dof: int + + +@dataclass +class AssemblyLabel: + """Assembly-wide summary label.""" + + classification: str + total_dof: int + redundant_count: int + is_rigid: bool + is_minimally_rigid: bool + has_degeneracy: bool + + +@dataclass +class AssemblyLabels: + """Complete ground truth labels for an assembly.""" + + per_constraint: list[ConstraintLabel] + per_joint: list[JointLabel] + per_body: list[BodyDofLabel] + assembly: AssemblyLabel + analysis: ConstraintAnalysis + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "per_constraint": [ + { + "joint_id": c.joint_id, + "constraint_idx": c.constraint_idx, + "pebble_independent": c.pebble_independent, + "jacobian_independent": c.jacobian_independent, + } + for c in self.per_constraint + ], + "per_joint": [ + { + "joint_id": j.joint_id, + "independent_count": j.independent_count, + "redundant_count": j.redundant_count, + "total": j.total, + } + for j in self.per_joint + ], + "per_body": [ + { + "body_id": b.body_id, + "translational_dof": b.translational_dof, + "rotational_dof": b.rotational_dof, + } + for b in self.per_body + ], + "assembly": { + "classification": self.assembly.classification, + "total_dof": self.assembly.total_dof, + "redundant_count": self.assembly.redundant_count, + "is_rigid": self.assembly.is_rigid, + "is_minimally_rigid": self.assembly.is_minimally_rigid, + "has_degeneracy": self.assembly.has_degeneracy, + }, + } + + +# --------------------------------------------------------------------------- +# Per-body DOF from nullspace projection +# --------------------------------------------------------------------------- + + +def _compute_per_body_dof( + j_reduced: np.ndarray, + body_ids: list[int], + ground_body: int | None, + body_index: dict[int, int], +) -> list[BodyDofLabel]: + """Compute translational and rotational DOF per body. 
+ + Uses SVD nullspace projection: for each body, extract its + translational (3 cols) and rotational (3 cols) components + from the nullspace basis and compute ranks. + """ + # Build column index mapping for the reduced Jacobian + # (ground body columns have been removed) + col_map: dict[int, int] = {} + col_idx = 0 + for bid in body_ids: + if bid == ground_body: + continue + col_map[bid] = col_idx + col_idx += 1 + + results: list[BodyDofLabel] = [] + + if j_reduced.size == 0: + # No constraints — every body is fully free + for bid in body_ids: + if bid == ground_body: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + else: + results.append(BodyDofLabel(body_id=bid, translational_dof=3, rotational_dof=3)) + return results + + # Full SVD to get nullspace + _u, s, vh = np.linalg.svd(j_reduced, full_matrices=True) + rank = int(np.sum(s > _SVD_TOL)) + n_cols = j_reduced.shape[1] + + if rank >= n_cols: + # Fully constrained — no nullspace + for bid in body_ids: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + return results + + # Nullspace basis: rows of Vh beyond the rank + nullspace = vh[rank:] # shape: (n_cols - rank, n_cols) + + for bid in body_ids: + if bid == ground_body: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + continue + + idx = col_map[bid] + trans_cols = nullspace[:, idx * 6 : idx * 6 + 3] + rot_cols = nullspace[:, idx * 6 + 3 : idx * 6 + 6] + + # Rank of each block = DOF in that category + if trans_cols.size > 0: + sv_t = np.linalg.svd(trans_cols, compute_uv=False) + t_dof = int(np.sum(sv_t > _SVD_TOL)) + else: + t_dof = 0 + + if rot_cols.size > 0: + sv_r = np.linalg.svd(rot_cols, compute_uv=False) + r_dof = int(np.sum(sv_r > _SVD_TOL)) + else: + r_dof = 0 + + results.append(BodyDofLabel(body_id=bid, translational_dof=t_dof, rotational_dof=r_dof)) + + return results + + +# 
--------------------------------------------------------------------------- +# Main labeling function +# --------------------------------------------------------------------------- + + +def label_assembly( + bodies: list[RigidBody], + joints: list[Joint], + ground_body: int | None = None, +) -> AssemblyLabels: + """Produce complete ground truth labels for an assembly. + + Runs both the pebble game and Jacobian verification internally, + then correlates their results into per-constraint, per-joint, + per-body, and assembly-level labels. + + Args: + bodies: Rigid bodies in the assembly. + joints: Joints connecting the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + AssemblyLabels with full label set and embedded ConstraintAnalysis. + """ + # ---- Pebble Game ---- + pg = PebbleGame3D() + all_edge_results: list[dict[str, Any]] = [] + + if ground_body is not None: + pg.add_body(_GROUND_ID) + + for body in bodies: + pg.add_body(body.body_id) + + if ground_body is not None: + ground_joint = Joint( + joint_id=-1, + body_a=ground_body, + body_b=_GROUND_ID, + joint_type=JointType.FIXED, + anchor_a=bodies[0].position if bodies else np.zeros(3), + anchor_b=bodies[0].position if bodies else np.zeros(3), + ) + pg.add_joint(ground_joint) + + for joint in joints: + results = pg.add_joint(joint) + all_edge_results.extend(results) + + grounded = ground_body is not None + combinatorial_independent = len(pg.state.independent_edges) + raw_dof = pg.get_dof() + ground_offset = 6 if grounded else 0 + effective_dof = raw_dof - ground_offset + effective_internal_dof = max(0, effective_dof - (0 if grounded else 6)) + + redundant_count = pg.get_redundant_count() + if redundant_count > 0 and effective_internal_dof > 0: + classification = "mixed" + elif redundant_count > 0: + classification = "overconstrained" + elif effective_internal_dof > 0: + classification = "underconstrained" + else: + classification = "well-constrained" + + # ---- Jacobian Verification ---- 
+ verifier = JacobianVerifier(bodies) + + for joint in joints: + verifier.add_joint_constraints(joint) + + j_full = verifier.get_jacobian() + j_reduced = j_full.copy() + if ground_body is not None and j_reduced.size > 0: + idx = verifier.body_index[ground_body] + cols_to_remove = list(range(idx * 6, (idx + 1) * 6)) + j_reduced = np.delete(j_reduced, cols_to_remove, axis=1) + + if j_reduced.size > 0: + sv = np.linalg.svd(j_reduced, compute_uv=False) + jacobian_rank = int(np.sum(sv > _SVD_TOL)) + else: + jacobian_rank = 0 + + n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies) + jacobian_nullity = n_cols - jacobian_rank + dependent_rows = verifier.find_dependencies() + dependent_set = set(dependent_rows) + + trivial_dof = 0 if grounded else 6 + jacobian_internal_dof = jacobian_nullity - trivial_dof + geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank) + is_rigid = jacobian_nullity <= trivial_dof + is_minimally_rigid = is_rigid and len(dependent_rows) == 0 + + # ---- Per-constraint labels ---- + # Map Jacobian rows to (joint_id, constraint_index). + # Rows are added contiguously per joint in the same order as joints. 
+ row_to_joint: list[tuple[int, int]] = [] + for joint in joints: + dof = joint.joint_type.dof + for ci in range(dof): + row_to_joint.append((joint.joint_id, ci)) + + per_constraint: list[ConstraintLabel] = [] + for edge_idx, edge_result in enumerate(all_edge_results): + jid = edge_result["joint_id"] + ci = edge_result["constraint_index"] + pebble_indep = edge_result["independent"] + + # Find matching Jacobian row + jacobian_indep = True + if edge_idx < len(row_to_joint): + row_idx = edge_idx + jacobian_indep = row_idx not in dependent_set + + per_constraint.append( + ConstraintLabel( + joint_id=jid, + constraint_idx=ci, + pebble_independent=pebble_indep, + jacobian_independent=jacobian_indep, + ) + ) + + # ---- Per-joint labels ---- + joint_agg: dict[int, JointLabel] = {} + for cl in per_constraint: + if cl.joint_id not in joint_agg: + joint_agg[cl.joint_id] = JointLabel( + joint_id=cl.joint_id, + independent_count=0, + redundant_count=0, + total=0, + ) + jl = joint_agg[cl.joint_id] + jl.total += 1 + if cl.pebble_independent: + jl.independent_count += 1 + else: + jl.redundant_count += 1 + + per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg] + + # ---- Per-body DOF labels ---- + body_ids = [b.body_id for b in bodies] + per_body = _compute_per_body_dof( + j_reduced, + body_ids, + ground_body, + verifier.body_index, + ) + + # ---- Assembly label ---- + assembly_label = AssemblyLabel( + classification=classification, + total_dof=max(0, jacobian_internal_dof), + redundant_count=redundant_count, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + has_degeneracy=geometric_degeneracies > 0, + ) + + # ---- ConstraintAnalysis (for backward compat) ---- + analysis = ConstraintAnalysis( + combinatorial_dof=effective_dof, + combinatorial_internal_dof=effective_internal_dof, + combinatorial_redundant=redundant_count, + combinatorial_classification=classification, + per_edge_results=all_edge_results, + jacobian_rank=jacobian_rank, + 
jacobian_nullity=jacobian_nullity, + jacobian_internal_dof=max(0, jacobian_internal_dof), + numerically_dependent=dependent_rows, + geometric_degeneracies=geometric_degeneracies, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + ) + + return AssemblyLabels( + per_constraint=per_constraint, + per_joint=per_joint, + per_body=per_body, + assembly=assembly_label, + analysis=analysis, + ) diff --git a/solver/datagen/pebble_game.py b/solver/datagen/pebble_game.py new file mode 100644 index 0000000..e3bdbc3 --- /dev/null +++ b/solver/datagen/pebble_game.py @@ -0,0 +1,258 @@ +"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis. + +Implements the pebble game algorithm adapted for CAD assembly constraint +graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints +between bodies remove DOF according to their type. + +The pebble game provides a fast combinatorial *necessary* condition for +rigidity via Tay's theorem. It does not detect geometric degeneracies — +use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient* +condition. + +References: + - Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008 + - Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity + Percolation: The Pebble Game", J. Comput. Phys., 1997 + - Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from solver.datagen.types import Joint, PebbleState + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["PebbleGame3D"] + + +class PebbleGame3D: + """Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks. + + For body-bar-hinge structures in 3D, Tay's theorem states that a + multigraph G on n vertices is generically minimally rigid iff: + |E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2) + + The (6,6)-pebble game tests this sparsity condition incrementally. 
+ Each vertex starts with 6 pebbles (representing 6 DOF). To insert + an edge, we need to collect 6+1=7 pebbles on its two endpoints. + If we can, the edge is independent (removes a DOF). If not, it's + redundant (overconstrained). + + In the CAD assembly context: + - Vertices = rigid bodies + - Edges = scalar constraints from joints + - A revolute joint (5 DOF removed) maps to 5 multigraph edges + - A fixed joint (6 DOF removed) maps to 6 multigraph edges + """ + + K = 6 # Pebbles per vertex (DOF per rigid body in 3D) + L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge + + def __init__(self) -> None: + self.state = PebbleState() + self._edge_counter = 0 + self._bodies: set[int] = set() + + def add_body(self, body_id: int) -> None: + """Register a rigid body (vertex) with K=6 free pebbles.""" + if body_id in self._bodies: + return + self._bodies.add(body_id) + self.state.free_pebbles[body_id] = self.K + self.state.incoming[body_id] = set() + self.state.outgoing[body_id] = set() + + def add_joint(self, joint: Joint) -> list[dict[str, Any]]: + """Expand a joint into multigraph edges and test each for independence. + + A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph. + Each edge is tested individually via the pebble game. 
+ + Returns a list of dicts, one per scalar constraint, with: + - edge_id: int + - independent: bool + - dof_remaining: int (total free pebbles after this edge) + """ + self.add_body(joint.body_a) + self.add_body(joint.body_b) + + num_constraints = joint.joint_type.dof + results: list[dict[str, Any]] = [] + + for i in range(num_constraints): + edge_id = self._edge_counter + self._edge_counter += 1 + + independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b) + total_free = sum(self.state.free_pebbles.values()) + + results.append( + { + "edge_id": edge_id, + "joint_id": joint.joint_id, + "constraint_index": i, + "independent": independent, + "dof_remaining": total_free, + } + ) + + return results + + def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool: + """Try to insert a directed edge between u and v. + + The edge is accepted (independent) iff we can collect L+1 = 7 + pebbles on the two endpoints {u, v} combined. + + If accepted, one pebble is consumed and the edge is directed + away from the vertex that gives up the pebble. 
+ """ + # Count current free pebbles on u and v + available = self.state.free_pebbles[u] + self.state.free_pebbles[v] + + # Try to gather enough pebbles via DFS reachability search + if available < self.L + 1: + needed = (self.L + 1) - available + # Try to free pebbles by searching from u first, then v + for target in (u, v): + while needed > 0: + found = self._search_and_collect(target, frozenset({u, v})) + if not found: + break + needed -= 1 + + # Recheck after collection attempts + available = self.state.free_pebbles[u] + self.state.free_pebbles[v] + + if available >= self.L + 1: + # Accept: consume a pebble from whichever endpoint has one + source = u if self.state.free_pebbles[u] > 0 else v + + self.state.free_pebbles[source] -= 1 + self.state.directed_edges[edge_id] = (source, v if source == u else u) + self.state.outgoing[source].add((edge_id, v if source == u else u)) + target = v if source == u else u + self.state.incoming[target].add((edge_id, source)) + self.state.independent_edges.add(edge_id) + return True + else: + # Reject: edge is redundant (overconstrained) + self.state.redundant_edges.add(edge_id) + return False + + def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool: + """DFS to find a free pebble reachable from *target* and move it. + + Follows directed edges *backwards* (from destination to source) + to find a vertex with a free pebble that isn't in *forbidden*. + When found, reverses the path to move the pebble to *target*. + + Returns True if a pebble was successfully moved to target. + """ + # BFS/DFS through the directed graph following outgoing edges + # from target. An outgoing edge (target -> w) means target spent + # a pebble on that edge. If we can find a vertex with a free + # pebble, we reverse edges along the path to move it. 
    def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
        """DFS to find a free pebble reachable from *target* and move it.

        Follows directed edges *backwards* (from destination to source)
        to find a vertex with a free pebble that isn't in *forbidden*.
        When found, reverses the path to move the pebble to *target*.

        Returns True if a pebble was successfully moved to target.
        """
        # BFS/DFS through the directed graph following outgoing edges
        # from target. An outgoing edge (target -> w) means target spent
        # a pebble on that edge. If we can find a vertex with a free
        # pebble, we reverse edges along the path to move it.

        visited: set[int] = set()
        # Stack: (current_vertex, path_of_edge_ids_to_reverse)
        stack: list[tuple[int, list[int]]] = [(target, [])]

        while stack:
            current, path = stack.pop()
            if current in visited:
                continue
            visited.add(current)

            # Check if current vertex (not in forbidden, not target)
            # has a free pebble
            if (
                current != target
                and current not in forbidden
                and self.state.free_pebbles[current] > 0
            ):
                # Found a pebble — reverse the path
                self._reverse_path(path, current)
                return True

            # Follow outgoing edges from current vertex
            for eid, neighbor in self.state.outgoing.get(current, set()):
                if neighbor not in visited:
                    stack.append((neighbor, [*path, eid]))

        return False

    def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
        """Reverse directed edges along a path, moving a pebble to the start.

        The pebble at *pebble_source* is consumed by the last edge in
        the path, and a pebble is freed at the path's start vertex.
        """
        if not edge_ids:
            return

        # Reverse each edge in the path
        for eid in edge_ids:
            old_source, old_target = self.state.directed_edges[eid]

            # Remove from adjacency
            self.state.outgoing[old_source].discard((eid, old_target))
            self.state.incoming[old_target].discard((eid, old_source))

            # Reverse direction
            self.state.directed_edges[eid] = (old_target, old_source)
            self.state.outgoing[old_target].add((eid, old_source))
            self.state.incoming[old_source].add((eid, old_target))

        # Move pebble counts: source loses one, first vertex in path gains one
        self.state.free_pebbles[pebble_source] -= 1

        # After all reversals, the vertex at the beginning of the
        # search path gains a pebble.  Since edge_ids[0] was originally
        # directed away from the search start, its post-reversal target
        # is that start vertex.
        _first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
        self.state.free_pebbles[first_tgt] += 1
+ + For a fully rigid assembly, this should be 6 (the trivial rigid + body motions of the whole assembly). Internal DOF = total - 6. + """ + return sum(self.state.free_pebbles.values()) + + def get_internal_dof(self) -> int: + """Internal (non-trivial) degrees of freedom.""" + return max(0, self.get_dof() - 6) + + def is_rigid(self) -> bool: + """Combinatorial rigidity check: rigid iff at most 6 pebbles remain.""" + return self.get_dof() <= self.L + + def get_redundant_count(self) -> int: + """Number of redundant (overconstrained) scalar constraints.""" + return len(self.state.redundant_edges) + + def classify_assembly(self, *, grounded: bool = False) -> str: + """Classify the assembly state. + + Args: + grounded: If True, the baseline trivial DOF is 0 (not 6), + because the ground body's 6 DOF were removed. + """ + total_dof = self.get_dof() + redundant = self.get_redundant_count() + baseline = 0 if grounded else self.L + + if redundant > 0 and total_dof > baseline: + return "mixed" # Both under and over-constrained regions + elif redundant > 0: + return "overconstrained" + elif total_dof > baseline: + return "underconstrained" + elif total_dof == baseline: + return "well-constrained" + else: + return "overconstrained" diff --git a/solver/datagen/types.py b/solver/datagen/types.py new file mode 100644 index 0000000..f7b3e84 --- /dev/null +++ b/solver/datagen/types.py @@ -0,0 +1,144 @@ +"""Shared data types for assembly constraint analysis. + +Types ported from the pebble-game synthetic data generator for reuse +across the solver package (data generation, training, inference). 
+""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "ConstraintAnalysis", + "Joint", + "JointType", + "PebbleState", + "RigidBody", +] + + +# --------------------------------------------------------------------------- +# Joint definitions: each joint type removes a known number of DOF +# --------------------------------------------------------------------------- + + +class JointType(enum.Enum): + """Standard CAD joint types with their DOF-removal counts. + + Each joint between two 6-DOF rigid bodies removes a specific number + of relative degrees of freedom. In the body-bar-hinge multigraph + representation, each joint maps to a number of edges equal to the + DOF it removes. + + Values are ``(ordinal, dof_removed)`` tuples so that joint types + sharing the same DOF count remain distinct enum members. Use the + :attr:`dof` property to get the scalar constraint count. 
+ """ + + FIXED = (0, 6) # Locks all relative motion + REVOLUTE = (1, 5) # Allows rotation about one axis + CYLINDRICAL = (2, 4) # Allows rotation + translation along one axis + SLIDER = (3, 5) # Allows translation along one axis (prismatic) + BALL = (4, 3) # Allows rotation about a point (spherical) + PLANAR = (5, 3) # Allows 2D translation + rotation normal to plane + SCREW = (6, 5) # Coupled rotation-translation (helical) + UNIVERSAL = (7, 4) # Two rotational DOF (Cardan/U-joint) + PARALLEL = (8, 3) # Forces parallel orientation (3 rotation constraints) + PERPENDICULAR = (9, 1) # Single angular constraint + DISTANCE = (10, 1) # Single scalar distance constraint + + @property + def dof(self) -> int: + """Number of scalar constraints (DOF removed) by this joint type.""" + return self.value[1] + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class RigidBody: + """A rigid body in the assembly with pose and geometry info.""" + + body_id: int + position: np.ndarray = field(default_factory=lambda: np.zeros(3)) + orientation: np.ndarray = field(default_factory=lambda: np.eye(3)) + + # Anchor points for joints, in local frame + # Populated when joints reference specific geometry + local_anchors: dict[str, np.ndarray] = field(default_factory=dict) + + +@dataclass +class Joint: + """A joint connecting two rigid bodies.""" + + joint_id: int + body_a: int # Index of first body + body_b: int # Index of second body + joint_type: JointType + + # Joint parameters in world frame + anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3)) + anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3)) + axis: np.ndarray = field( + default_factory=lambda: np.array([0.0, 0.0, 1.0]), + ) + + # For screw joints + pitch: float = 0.0 + + +@dataclass +class PebbleState: + """Tracks the state of the pebble game on the 
multigraph.""" + + # Number of free pebbles per body (vertex). Starts at 6. + free_pebbles: dict[int, int] = field(default_factory=dict) + + # Directed edges: edge_id -> (source_body, target_body) + # Edge is directed away from the body that "spent" a pebble. + directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict) + + # Track which edges are independent vs redundant + independent_edges: set[int] = field(default_factory=set) + redundant_edges: set[int] = field(default_factory=set) + + # Adjacency: body_id -> set of (edge_id, neighbor_body_id) + # Following directed edges *towards* a body (incoming edges) + incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + # Outgoing edges from a body + outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + +@dataclass +class ConstraintAnalysis: + """Results of analyzing an assembly's constraint system.""" + + # Pebble game (combinatorial) results + combinatorial_dof: int + combinatorial_internal_dof: int + combinatorial_redundant: int + combinatorial_classification: str + per_edge_results: list[dict[str, Any]] + + # Numerical (Jacobian) results + jacobian_rank: int + jacobian_nullity: int # = 6n - rank = total DOF + jacobian_internal_dof: int # = nullity - 6 + numerically_dependent: list[int] + + # Combined + geometric_degeneracies: int # = combinatorial_independent - jacobian_rank + is_rigid: bool + is_minimally_rigid: bool diff --git a/solver/datasets/__init__.py b/solver/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/evaluation/__init__.py b/solver/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/inference/__init__.py b/solver/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/models/__init__.py b/solver/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/training/__init__.py b/solver/training/__init__.py new file mode 100644 
def _two_bodies() -> list[RigidBody]:
    """Two bodies two units apart along the x axis."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(p)) for i, p in enumerate(coords)]


def _triangle_bodies() -> list[RigidBody]:
    """Three bodies arranged as a rough triangle in the xy plane."""
    coords = (
        [0.0, 0.0, 0.0],
        [2.0, 0.0, 0.0],
        [1.0, 1.7, 0.0],
    )
    return [RigidBody(i, position=np.array(p)) for i, p in enumerate(coords)]
class TestTwoBodiesFixed:
    """Demo scenario 2: two bodies connected by a fixed joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # A single FIXED joint removes all 6 relative DOF.
        fixed = Joint(
            0,
            body_a=0,
            body_b=1,
            joint_type=JointType.FIXED,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([1.0, 0.0, 0.0]),
        )
        return analyze_assembly(_two_bodies(), [fixed], ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        assert result.jacobian_internal_dof == 0

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_minimally_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification == "well-constrained"
class TestTriangleRevolute:
    """Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # (joint_id, body_a, body_b, shared anchor point)
        spec = [
            (0, 0, 1, [1.0, 0.0, 0.0]),
            (1, 1, 2, [1.5, 0.85, 0.0]),
            (2, 2, 0, [0.5, 0.85, 0.0]),
        ]
        joints = [
            Joint(
                jid,
                body_a=a,
                body_b=b,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array(anchor),
                anchor_b=np.array(anchor),
                axis=np.array([0.0, 0.0, 1.0]),
            )
            for jid, a, b, anchor in spec
        ]
        return analyze_assembly(_triangle_bodies(), joints, ground_body=0)

    def test_has_redundant(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_redundant > 0

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification in ("overconstrained", "mixed")

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
        assert len(result.numerically_dependent) > 0
class TestNoJoints:
    """Assembly with bodies but no joints."""

    def test_all_dof_free(self) -> None:
        outcome = analyze_assembly(_two_bodies(), [], ground_body=0)
        # Body 1 is completely free (6 DOF), body 0 is grounded
        assert outcome.jacobian_internal_dof > 0
        assert not outcome.is_rigid

    def test_ungrounded(self) -> None:
        outcome = analyze_assembly(_two_bodies(), [])
        assert outcome.combinatorial_classification == "underconstrained"


class TestReturnType:
    """Verify the return object is a proper ConstraintAnalysis."""

    def test_instance(self) -> None:
        joints = [Joint(0, 0, 1, JointType.FIXED)]
        outcome = analyze_assembly(_two_bodies(), joints)
        assert isinstance(outcome, ConstraintAnalysis)

    def test_per_edge_results_populated(self) -> None:
        joints = [Joint(0, 0, 1, JointType.REVOLUTE)]
        outcome = analyze_assembly(_two_bodies(), joints)
        # A revolute joint contributes 5 scalar constraints
        assert len(outcome.per_edge_results) == 5
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        cfg = DatasetConfig()
        assert (cfg.num_assemblies, cfg.seed) == (100_000, 42)
        assert (cfg.shard_size, cfg.num_workers) == (1000, 4)

    def test_from_dict_flat(self) -> None:
        cfg = DatasetConfig.from_dict({"num_assemblies": 500, "seed": 123})
        assert cfg.num_assemblies == 500
        assert cfg.seed == 123

    def test_from_dict_nested_body_count(self) -> None:
        cfg = DatasetConfig.from_dict({"body_count": {"min": 3, "max": 20}})
        assert (cfg.body_count_min, cfg.body_count_max) == (3, 20)

    def test_from_dict_flat_body_count(self) -> None:
        cfg = DatasetConfig.from_dict({"body_count_min": 5, "body_count_max": 30})
        assert (cfg.body_count_min, cfg.body_count_max) == (5, 30)

    def test_from_dict_complexity_distribution(self) -> None:
        dist: dict[str, Any] = {"simple": 0.6, "complex": 0.4}
        cfg = DatasetConfig.from_dict({"complexity_distribution": dist})
        assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        cfg = DatasetConfig.from_dict({"templates": ["chain", "tree"]})
        assert cfg.templates == ["chain", "tree"]
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format.

    Covers flat scalars, one level of nesting, lists, inline comments,
    and finally the real project config file.
    """

    def test_flat_scalars(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["name"] == "test"
        assert result["num"] == 42
        assert result["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("body_count:\n min: 2\n max: 50\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("templates:\n - chain\n - tree\n - loop\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("dist:\n simple: 0.4 # comment\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config.

        The path is anchored on this file's location instead of the
        CWD-relative string the original used, so the test still passes
        when pytest is invoked from outside the repository root.
        """
        from pathlib import Path as _Path

        config = _Path(__file__).resolve().parents[2] / "configs" / "dataset" / "synthetic.yaml"
        result = parse_simple_yaml(str(config))
        assert result["name"] == "synthetic"
        assert result["num_assemblies"] == 100000
        assert isinstance(result["complexity_distribution"], dict)
        assert isinstance(result["templates"], list)
        assert result["shard_size"] == 1000
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        # Same (seed, shard) pair always yields the same derived seed.
        assert _derive_shard_seed(42, 0) == _derive_shard_seed(42, 0)

    def test_different_shards(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
"""Resume skips already-completed shards.""" + cfg = DatasetConfig( + num_assemblies=20, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + # Record shard modification times + shards_dir = tmp_path / "output" / "shards" + mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")} + + # Remove stats (simulate incomplete) and re-run + (tmp_path / "output" / "stats.json").unlink() + + DatasetGenerator(cfg).run() + + # Shards should NOT have been regenerated + for p in shards_dir.glob("shard_*"): + assert p.stat().st_mtime == mtimes[p.name] + + # Stats should be regenerated + assert (tmp_path / "output" / "stats.json").exists() + + def test_checkpoint_removed(self, tmp_path: Path) -> None: + """Checkpoint file is cleaned up after completion.""" + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / "output"), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + checkpoint = tmp_path / "output" / ".checkpoint.json" + assert not checkpoint.exists() + + def test_stats_structure(self, tmp_path: Path) -> None: + """stats.json has expected top-level keys.""" + cfg = DatasetConfig( + num_assemblies=10, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + stats = json.loads((tmp_path / "output" / "stats.json").read_text()) + assert stats["total_examples"] == 10 + assert "classification_distribution" in stats + assert "body_count_histogram" in stats + assert "joint_type_distribution" in stats + assert "dof_statistics" in stats + assert "geometric_degeneracy" in stats + assert "rigidity" in stats + + def test_index_structure(self, tmp_path: Path) -> None: + """index.json has expected format.""" + cfg = DatasetConfig( + num_assemblies=15, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + index = json.loads((tmp_path 
/ "output" / "index.json").read_text()) + assert index["format_version"] == 1 + assert index["total_assemblies"] == 15 + assert index["total_shards"] == 2 + assert "shards" in index + for _name, info in index["shards"].items(): + assert "start_id" in info + assert "count" in info + + def test_deterministic_output(self, tmp_path: Path) -> None: + """Same seed produces same results.""" + for run_dir in ("run1", "run2"): + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / run_dir), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + s1 = json.loads((tmp_path / "run1" / "stats.json").read_text()) + s2 = json.loads((tmp_path / "run2" / "stats.json").read_text()) + assert s1["total_examples"] == s2["total_examples"] + assert s1["classification_distribution"] == s2["classification_distribution"] + + +# --------------------------------------------------------------------------- +# CLI integration test +# --------------------------------------------------------------------------- + + +class TestCLI: + """Run the script via subprocess.""" + + def test_argparse_mode(self, tmp_path: Path) -> None: + result = subprocess.run( + [ + sys.executable, + "scripts/generate_synthetic.py", + "--num-assemblies", + "5", + "--output-dir", + str(tmp_path / "cli_out"), + "--shard-size", + "5", + "--num-workers", + "1", + "--seed", + "42", + ], + capture_output=True, + text=True, + cwd="/home/developer", + timeout=120, + env={**os.environ, "PYTHONPATH": "/home/developer"}, + ) + assert result.returncode == 0, ( + f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) + assert (tmp_path / "cli_out" / "index.json").exists() + assert (tmp_path / "cli_out" / "stats.json").exists() diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py new file mode 100644 index 0000000..3b120fa --- /dev/null +++ b/tests/datagen/test_generator.py @@ -0,0 +1,682 @@ +"""Tests for solver.datagen.generator -- synthetic assembly 
generation.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np +import pytest + +from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator +from solver.datagen.types import JointType + +# --------------------------------------------------------------------------- +# Original generators (chain / rigid / overconstrained) +# --------------------------------------------------------------------------- + + +class TestChainAssembly: + """generate_chain_assembly produces valid underconstrained chains.""" + + def test_returns_three_tuple(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + bodies, joints, _analysis = gen.generate_chain_assembly(4) + assert len(bodies) == 4 + assert len(joints) == 3 + + def test_chain_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, _, analysis = gen.generate_chain_assembly(4) + assert analysis.combinatorial_classification == "underconstrained" + + def test_chain_body_ids(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + bodies, _, _ = gen.generate_chain_assembly(5) + ids = [b.body_id for b in bodies] + assert ids == [0, 1, 2, 3, 4] + + def test_chain_joint_connectivity(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, joints, _ = gen.generate_chain_assembly(4) + for i, j in enumerate(joints): + assert j.body_a == i + assert j.body_b == i + 1 + + def test_chain_custom_joint_type(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, joints, _ = gen.generate_chain_assembly( + 3, + joint_type=JointType.BALL, + ) + assert all(j.joint_type is JointType.BALL for j in joints) + + +class TestRigidAssembly: + """generate_rigid_assembly produces rigid assemblies.""" + + def test_rigid(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_rigid_assembly(4) + assert analysis.is_rigid + + def test_spanning_tree_structure(self) -> None: + """n bodies should have at least n-1 joints 
(spanning tree).""" + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_rigid_assembly(5) + assert len(joints) >= len(bodies) - 1 + + def test_deterministic(self) -> None: + """Same seed produces same results.""" + g1 = SyntheticAssemblyGenerator(seed=99) + g2 = SyntheticAssemblyGenerator(seed=99) + _, j1, a1 = g1.generate_rigid_assembly(4) + _, j2, a2 = g2.generate_rigid_assembly(4) + assert a1.jacobian_rank == a2.jacobian_rank + assert len(j1) == len(j2) + + +class TestOverconstrainedAssembly: + """generate_overconstrained_assembly adds redundant constraints.""" + + def test_has_redundant(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_overconstrained_assembly( + 4, + extra_joints=2, + ) + assert analysis.combinatorial_redundant > 0 + + def test_extra_joints_added(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints_base, _ = gen.generate_rigid_assembly(4) + + gen2 = SyntheticAssemblyGenerator(seed=42) + _, joints_over, _ = gen2.generate_overconstrained_assembly( + 4, + extra_joints=3, + ) + # Overconstrained has base joints + extra + assert len(joints_over) > len(joints_base) + + +# --------------------------------------------------------------------------- +# New topology generators +# --------------------------------------------------------------------------- + + +class TestTreeAssembly: + """generate_tree_assembly produces tree-structured assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_tree_assembly(6) + assert len(bodies) == 6 + assert len(joints) == 5 # n - 1 + + def test_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_tree_assembly(6) + assert analysis.combinatorial_classification == "underconstrained" + + def test_branching_factor(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = 
gen.generate_tree_assembly( + 10, + branching_factor=2, + ) + assert len(bodies) == 10 + assert len(joints) == 9 + + def test_mixed_joint_types(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED] + _, joints, _ = gen.generate_tree_assembly(10, joint_types=types) + used = {j.joint_type for j in joints} + # With 9 joints and 3 types, very likely to use at least 2 + assert len(used) >= 2 + + def test_single_joint_type(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_tree_assembly( + 5, + joint_types=JointType.BALL, + ) + assert all(j.joint_type is JointType.BALL for j in joints) + + def test_sequential_body_ids(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_tree_assembly(7) + assert [b.body_id for b in bodies] == list(range(7)) + + +class TestLoopAssembly: + """generate_loop_assembly produces closed-loop assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_loop_assembly(5) + assert len(bodies) == 5 + assert len(joints) == 5 # n joints for n bodies + + def test_has_redundancy(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_loop_assembly(5) + assert analysis.combinatorial_redundant > 0 + + def test_wrap_around_connectivity(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_loop_assembly(4) + edges = {(j.body_a, j.body_b) for j in joints} + assert (0, 1) in edges + assert (1, 2) in edges + assert (2, 3) in edges + assert (3, 0) in edges # wrap-around + + def test_minimum_bodies_error(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="at least 3"): + gen.generate_loop_assembly(2) + + def test_mixed_joint_types(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + types = [JointType.REVOLUTE, 
JointType.FIXED] + _, joints, _ = gen.generate_loop_assembly(8, joint_types=types) + used = {j.joint_type for j in joints} + assert len(used) >= 2 + + +class TestStarAssembly: + """generate_star_assembly produces hub-and-spoke assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_star_assembly(6) + assert len(bodies) == 6 + assert len(joints) == 5 # n - 1 + + def test_all_joints_connect_to_hub(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_star_assembly(6) + for j in joints: + assert j.body_a == 0 or j.body_b == 0 + + def test_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_star_assembly(5) + assert analysis.combinatorial_classification == "underconstrained" + + def test_minimum_bodies_error(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="at least 2"): + gen.generate_star_assembly(1) + + def test_hub_at_origin(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_star_assembly(4) + np.testing.assert_array_equal(bodies[0].position, np.zeros(3)) + + +class TestMixedAssembly: + """generate_mixed_assembly produces tree+loop hybrid assemblies.""" + + def test_more_joints_than_tree(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_mixed_assembly( + 8, + edge_density=0.3, + ) + assert len(joints) > len(bodies) - 1 + + def test_density_zero_is_tree(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=0.0, + ) + assert len(joints) == 4 # spanning tree only + + def test_density_validation(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="must be in"): + gen.generate_mixed_assembly(5, edge_density=1.5) + with pytest.raises(ValueError, 
match="must be in"): + gen.generate_mixed_assembly(5, edge_density=-0.1) + + def test_no_duplicate_edges(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5) + edges = [frozenset([j.body_a, j.body_b]) for j in joints] + assert len(edges) == len(set(edges)) + + def test_high_density(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=1.0, + ) + # Fully connected: 5*(5-1)/2 = 10 edges + assert len(joints) == 10 + + +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + + +class TestAxisStrategy: + """Axis sampling strategies produce valid unit vectors.""" + + def test_cardinal_axis_from_six(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axes = {tuple(gen._cardinal_axis()) for _ in range(200)} + expected = { + (1, 0, 0), + (-1, 0, 0), + (0, 1, 0), + (0, -1, 0), + (0, 0, 1), + (0, 0, -1), + } + assert axes == expected + + def test_random_axis_unit_norm(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + for _ in range(50): + axis = gen._sample_axis("random") + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + + def test_near_parallel_close_to_base(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + base = np.array([0.0, 0.0, 1.0]) + for _ in range(50): + axis = gen._near_parallel_axis(base) + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + assert np.dot(axis, base) > 0.95 + + def test_sample_axis_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("cardinal") + cardinals = [ + np.array(v, dtype=float) + for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)] + ] + assert any(np.allclose(axis, c) for c in cardinals) + + def test_sample_axis_near_parallel(self) -> None: + gen = 
SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("near_parallel") + z = np.array([0.0, 0.0, 1.0]) + assert np.dot(axis, z) > 0.95 + + +# --------------------------------------------------------------------------- +# Geometric diversity: orientations +# --------------------------------------------------------------------------- + + +class TestRandomOrientations: + """Bodies should have non-identity orientations.""" + + def test_bodies_have_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_tree_assembly(5) + non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3))) + assert non_identity >= 3 + + def test_orientations_are_valid_rotations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_star_assembly(6) + for b in bodies: + r = b.orientation + # R^T R == I + np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10) + # det(R) == 1 + assert abs(np.linalg.det(r) - 1.0) < 1e-10 + + def test_all_generators_set_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + # Chain + bodies, _, _ = gen.generate_chain_assembly(3) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Loop + bodies, _, _ = gen.generate_loop_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Mixed + bodies, _, _ = gen.generate_mixed_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + + +# --------------------------------------------------------------------------- +# Geometric diversity: grounded parameter +# --------------------------------------------------------------------------- + + +class TestGroundedParameter: + """Grounded parameter controls ground_body in analysis.""" + + def test_chain_grounded_default(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly(4) + assert analysis.combinatorial_dof >= 0 + + def test_chain_floating(self) -> None: + 
gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly( + 4, + grounded=False, + ) + # Floating: 6 trivial DOF not subtracted by ground + assert analysis.combinatorial_dof >= 6 + + def test_floating_vs_grounded_dof_difference(self) -> None: + gen1 = SyntheticAssemblyGenerator(seed=42) + _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True) + gen2 = SyntheticAssemblyGenerator(seed=42) + _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False) + # Floating should have higher DOF due to missing ground constraint + assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof + + @pytest.mark.parametrize( + "gen_method", + [ + "generate_chain_assembly", + "generate_rigid_assembly", + "generate_tree_assembly", + "generate_loop_assembly", + "generate_star_assembly", + "generate_mixed_assembly", + ], + ) + def test_all_generators_accept_grounded( + self, + gen_method: str, + ) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + method = getattr(gen, gen_method) + n = 4 + # Should not raise + if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"): + method(n, grounded=False) + else: + method(n, grounded=False) + + +# --------------------------------------------------------------------------- +# Geometric diversity: parallel axis injection +# --------------------------------------------------------------------------- + + +class TestParallelAxisInjection: + """parallel_axis_prob causes shared axis direction.""" + + def test_parallel_axes_similar(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + # Near-parallel: |dot| close to 1 + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_zero_prob_no_forced_parallel(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=0.0, + ) + base = 
joints[0].axis + dots = [abs(np.dot(j.axis, base)) for j in joints[1:]] + # With 5 random axes, extremely unlikely all are parallel + assert min(dots) < 0.95 + + def test_parallel_on_loop(self) -> None: + """Parallel axes on a loop assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_loop_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_parallel_on_star(self) -> None: + """Parallel axes on a star assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_star_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + +# --------------------------------------------------------------------------- +# Complexity tiers +# --------------------------------------------------------------------------- + + +class TestComplexityTiers: + """Complexity tier parameter on batch generation.""" + + def test_simple_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, complexity_tier="simple") + lo, hi = COMPLEXITY_RANGES["simple"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_medium_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, complexity_tier="medium") + lo, hi = COMPLEXITY_RANGES["medium"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_complex_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(3, complexity_tier="complex") + lo, hi = COMPLEXITY_RANGES["complex"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_tier_overrides_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + n_bodies_range=(2, 3), + complexity_tier="medium", + ) + lo, hi = COMPLEXITY_RANGES["medium"] 
+ for ex in batch: + assert lo <= ex["n_bodies"] < hi + + +# --------------------------------------------------------------------------- +# Training batch +# --------------------------------------------------------------------------- + + +class TestTrainingBatch: + """generate_training_batch produces well-structured examples.""" + + EXPECTED_KEYS: ClassVar[set[str]] = { + "example_id", + "generator_type", + "grounded", + "n_bodies", + "n_joints", + "body_positions", + "body_orientations", + "joints", + "joint_labels", + "labels", + "assembly_classification", + "is_rigid", + "is_minimally_rigid", + "internal_dof", + "geometric_degeneracies", + } + + VALID_GEN_TYPES: ClassVar[set[str]] = { + "chain", + "rigid", + "overconstrained", + "tree", + "loop", + "star", + "mixed", + } + + def test_batch_size(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20) + assert len(batch) == 20 + + def test_example_keys(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + assert set(ex.keys()) == self.EXPECTED_KEYS + + def test_example_ids_sequential(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(15) + assert [ex["example_id"] for ex in batch] == list(range(15)) + + def test_generator_type_valid(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(50) + for ex in batch: + assert ex["generator_type"] in self.VALID_GEN_TYPES + + def test_generator_type_diversity(self) -> None: + """100-sample batch should use at least 5 of 7 generator types.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100) + types = {ex["generator_type"] for ex in batch} + assert len(types) >= 5 + + def test_default_body_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) + for ex in batch: + # default (3, 8), but 
loop/star may clamp + assert 2 <= ex["n_bodies"] <= 7 + + def test_joint_label_consistency(self) -> None: + """independent + redundant == total for every joint.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) + for ex in batch: + for label in ex["joint_labels"].values(): + total = label["independent_constraints"] + label["redundant_constraints"] + assert total == label["total_constraints"] + + def test_body_orientations_present(self) -> None: + """Each example includes body_orientations as 3x3 lists.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + orients = ex["body_orientations"] + assert len(orients) == ex["n_bodies"] + for o in orients: + assert len(o) == 3 + assert len(o[0]) == 3 + + def test_labels_structure(self) -> None: + """Each example has labels dict with expected sub-keys.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + labels = ex["labels"] + assert "per_constraint" in labels + assert "per_joint" in labels + assert "per_body" in labels + assert "assembly" in labels + + def test_grounded_field_present(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + assert isinstance(ex["grounded"], bool) + + +# --------------------------------------------------------------------------- +# Batch grounded ratio +# --------------------------------------------------------------------------- + + +class TestBatchGroundedRatio: + """grounded_ratio controls the mix in batch generation.""" + + def test_all_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=1.0) + assert all(ex["grounded"] for ex in batch) + + def test_none_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=0.0) + assert not 
any(ex["grounded"] for ex in batch) + + def test_mixed_ratio(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100, grounded_ratio=0.5) + grounded_count = sum(1 for ex in batch if ex["grounded"]) + # With 100 samples and p=0.5, should be roughly 50 +/- 20 + assert 20 < grounded_count < 80 + + def test_batch_axis_strategy_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + axis_strategy="cardinal", + ) + assert len(batch) == 10 + + def test_batch_parallel_axis_prob(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + parallel_axis_prob=0.5, + ) + assert len(batch) == 10 + + +# --------------------------------------------------------------------------- +# Seed reproducibility +# --------------------------------------------------------------------------- + + +class TestSeedReproducibility: + """Different seeds produce different results.""" + + def test_different_seeds_differ(self) -> None: + g1 = SyntheticAssemblyGenerator(seed=1) + g2 = SyntheticAssemblyGenerator(seed=2) + b1 = g1.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) + b2 = g2.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) + c1 = [ex["assembly_classification"] for ex in b1] + c2 = [ex["assembly_classification"] for ex in b2] + r1 = [ex["is_rigid"] for ex in b1] + r2 = [ex["is_rigid"] for ex in b2] + assert c1 != c2 or r1 != r2 + + def test_same_seed_identical(self) -> None: + g1 = SyntheticAssemblyGenerator(seed=123) + g2 = SyntheticAssemblyGenerator(seed=123) + b1, j1, _ = g1.generate_tree_assembly(5) + b2, j2, _ = g2.generate_tree_assembly(5) + for a, b in zip(b1, b2, strict=True): + np.testing.assert_array_almost_equal(a.position, b.position) + assert len(j1) == len(j2) diff --git a/tests/datagen/test_jacobian.py b/tests/datagen/test_jacobian.py new file mode 100644 index 0000000..864331d --- 
/dev/null +++ b/tests/datagen/test_jacobian.py @@ -0,0 +1,267 @@ +"""Tests for solver.datagen.jacobian -- Jacobian rank verification.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.types import Joint, JointType, RigidBody + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _two_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + ] + + +def _three_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([4.0, 0.0, 0.0])), + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestJacobianShape: + """Verify Jacobian matrix dimensions for each joint type.""" + + @pytest.mark.parametrize( + "joint_type,expected_rows", + [ + (JointType.FIXED, 6), + (JointType.REVOLUTE, 5), + (JointType.CYLINDRICAL, 4), + (JointType.SLIDER, 5), + (JointType.BALL, 3), + (JointType.PLANAR, 3), + (JointType.SCREW, 5), + (JointType.UNIVERSAL, 4), + (JointType.PARALLEL, 3), + (JointType.PERPENDICULAR, 1), + (JointType.DISTANCE, 1), + ], + ) + def test_row_count(self, joint_type: JointType, expected_rows: int) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + joint = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=joint_type, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + n_added = v.add_joint_constraints(joint) + assert n_added == expected_rows + + j = v.get_jacobian() + assert j.shape == (expected_rows, 12) # 2 bodies * 6 cols 
+ + +class TestNumericalRank: + """Numerical rank checks for known configurations.""" + + def test_fixed_joint_rank_six(self) -> None: + """Fixed joint between 2 bodies: rank = 6.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 6 + + def test_revolute_joint_rank_five(self) -> None: + """Revolute joint between 2 bodies: rank = 5.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 5 + + def test_ball_joint_rank_three(self) -> None: + """Ball joint between 2 bodies: rank = 3.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.BALL, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 3 + + def test_empty_jacobian_rank_zero(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + assert v.numerical_rank() == 0 + + +class TestParallelAxesDegeneracy: + """Parallel revolute axes create geometric dependencies.""" + + def _four_body_loop(self) -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([2.0, 2.0, 0.0])), + RigidBody(3, position=np.array([0.0, 2.0, 0.0])), + ] + + def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]: + pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])] + return [ + Joint( + joint_id=i, + body_a=a, + body_b=b, + 
joint_type=JointType.REVOLUTE, + anchor_a=np.array(anc, dtype=float), + anchor_b=np.array(anc, dtype=float), + axis=axes[i], + ) + for i, (a, b, anc) in enumerate(pairs) + ] + + def test_parallel_has_lower_rank(self) -> None: + """4-body closed loop: all-parallel revolute axes produce lower + Jacobian rank than mixed axes due to geometric dependency.""" + bodies = self._four_body_loop() + z_axis = np.array([0.0, 0.0, 1.0]) + + # All axes parallel to Z + v_par = JacobianVerifier(bodies) + for j in self._loop_joints([z_axis] * 4): + v_par.add_joint_constraints(j) + rank_par = v_par.numerical_rank() + + # Mixed axes + mixed = [ + np.array([0.0, 0.0, 1.0]), + np.array([0.0, 1.0, 0.0]), + np.array([0.0, 0.0, 1.0]), + np.array([1.0, 0.0, 0.0]), + ] + v_mix = JacobianVerifier(bodies) + for j in self._loop_joints(mixed): + v_mix.add_joint_constraints(j) + rank_mix = v_mix.numerical_rank() + + assert rank_par < rank_mix + + +class TestFindDependencies: + """Dependency detection.""" + + def test_fixed_joint_no_dependencies(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.find_dependencies() == [] + + def test_duplicate_fixed_has_dependencies(self) -> None: + """Two fixed joints on same pair: second is fully dependent.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + for jid in range(2): + v.add_joint_constraints( + Joint( + joint_id=jid, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + ) + deps = v.find_dependencies() + assert len(deps) == 6 # Second fixed joint entirely redundant + + def test_empty_no_dependencies(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + assert v.find_dependencies() == [] + + +class TestRowLabels: + """Row 
label metadata.""" + + def test_labels_match_rows(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=7, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + v.add_joint_constraints(j) + assert len(v.row_labels) == 5 + assert all(lab["joint_id"] == 7 for lab in v.row_labels) + + +class TestPerpendicularPair: + """Internal _perpendicular_pair utility.""" + + @pytest.mark.parametrize( + "axis", + [ + np.array([1.0, 0.0, 0.0]), + np.array([0.0, 1.0, 0.0]), + np.array([0.0, 0.0, 1.0]), + np.array([1.0, 1.0, 1.0]) / np.sqrt(3), + ], + ) + def test_orthonormal(self, axis: np.ndarray) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + t1, t2 = v._perpendicular_pair(axis) + + # All unit length + np.testing.assert_allclose(np.linalg.norm(t1), 1.0, atol=1e-12) + np.testing.assert_allclose(np.linalg.norm(t2), 1.0, atol=1e-12) + + # Mutually perpendicular + np.testing.assert_allclose(np.dot(axis, t1), 0.0, atol=1e-12) + np.testing.assert_allclose(np.dot(axis, t2), 0.0, atol=1e-12) + np.testing.assert_allclose(np.dot(t1, t2), 0.0, atol=1e-12) diff --git a/tests/datagen/test_labeling.py b/tests/datagen/test_labeling.py new file mode 100644 index 0000000..a81e2e3 --- /dev/null +++ b/tests/datagen/test_labeling.py @@ -0,0 +1,346 @@ +"""Tests for solver.datagen.labeling -- ground truth labeling pipeline.""" + +from __future__ import annotations + +import json + +import numpy as np + +from solver.datagen.labeling import ( + label_assembly, +) +from solver.datagen.types import Joint, JointType, RigidBody + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]: + return [RigidBody(body_id=i, position=np.array(pos)) for i, pos 
in enumerate(positions)] + + +def _make_joint( + jid: int, + a: int, + b: int, + jtype: JointType, + axis: tuple[float, ...] = (0.0, 0.0, 1.0), +) -> Joint: + return Joint( + joint_id=jid, + body_a=a, + body_b=b, + joint_type=jtype, + anchor_a=np.zeros(3), + anchor_b=np.zeros(3), + axis=np.array(axis), + ) + + +# --------------------------------------------------------------------------- +# Per-constraint labels +# --------------------------------------------------------------------------- + + +class TestConstraintLabels: + """Per-constraint labels combine pebble game and Jacobian results.""" + + def test_fixed_joint_all_independent(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 6 + for cl in labels.per_constraint: + assert cl.pebble_independent is True + assert cl.jacobian_independent is True + + def test_revolute_joint_all_independent(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 5 + for cl in labels.per_constraint: + assert cl.pebble_independent is True + assert cl.jacobian_independent is True + + def test_chain_constraint_count(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 10 # 5 + 5 + + def test_constraint_joint_ids(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.BALL), + ] + labels = label_assembly(bodies, joints, ground_body=0) + j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0] + 
j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1] + assert len(j0_constraints) == 5 # revolute + assert len(j1_constraints) == 3 # ball + + def test_overconstrained_has_pebble_redundant(self) -> None: + """Triangle with revolute joints: some constraints redundant.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent) + assert pebble_redundant > 0 + + +# --------------------------------------------------------------------------- +# Per-joint labels +# --------------------------------------------------------------------------- + + +class TestJointLabels: + """Per-joint aggregated labels.""" + + def test_fixed_joint_counts(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_joint) == 1 + jl = labels.per_joint[0] + assert jl.joint_id == 0 + assert jl.independent_count == 6 + assert jl.redundant_count == 0 + assert jl.total == 6 + + def test_overconstrained_has_redundant_joints(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + total_redundant = sum(jl.redundant_count for jl in labels.per_joint) + assert total_redundant > 0 + + def test_joint_total_equals_dof(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.BALL)] + labels = label_assembly(bodies, joints, ground_body=0) + jl = labels.per_joint[0] + assert jl.total == 3 # ball has 3 DOF + + 
+# --------------------------------------------------------------------------- +# Per-body DOF labels +# --------------------------------------------------------------------------- + + +class TestBodyDofLabels: + """Per-body DOF signatures from nullspace projection.""" + + def test_fixed_joint_grounded_both_zero(self) -> None: + """Two bodies + fixed joint + grounded: both fully constrained.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + for bl in labels.per_body: + assert bl.translational_dof == 0 + assert bl.rotational_dof == 0 + + def test_revolute_has_rotational_dof(self) -> None: + """Two bodies + revolute + grounded: body 1 has rotational DOF.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + b1 = next(b for b in labels.per_body if b.body_id == 1) + # Revolute allows 1 rotation DOF + assert b1.rotational_dof >= 1 + + def test_dof_bounds(self) -> None: + """All DOF values should be in [0, 3].""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + for bl in labels.per_body: + assert 0 <= bl.translational_dof <= 3 + assert 0 <= bl.rotational_dof <= 3 + + def test_floating_more_dof_than_grounded(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + grounded = label_assembly(bodies, joints, ground_body=0) + floating = label_assembly(bodies, joints, ground_body=None) + g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body) + f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body) + assert f_total > g_total + + def test_grounded_body_zero_dof(self) -> None: + """The 
grounded body should have 0 DOF.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + b0 = next(b for b in labels.per_body if b.body_id == 0) + assert b0.translational_dof == 0 + assert b0.rotational_dof == 0 + + def test_body_count_matches(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.BALL), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_body) == 3 + + +# --------------------------------------------------------------------------- +# Assembly label +# --------------------------------------------------------------------------- + + +class TestAssemblyLabel: + """Assembly-wide summary labels.""" + + def test_underconstrained_chain(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.classification == "underconstrained" + assert labels.assembly.is_rigid is False + + def test_well_constrained(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.classification == "well-constrained" + assert labels.assembly.is_rigid is True + assert labels.assembly.is_minimally_rigid is True + + def test_overconstrained(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.redundant_count > 0 + + def test_has_degeneracy_with_parallel_axes(self) 
-> None: + """Parallel revolute axes in a loop create geometric degeneracy.""" + z_axis = (0.0, 0.0, 1.0) + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis), + _make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis), + _make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis), + _make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.has_degeneracy is True + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +class TestToDict: + """to_dict produces JSON-serializable output.""" + + def test_top_level_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + assert set(d.keys()) == { + "per_constraint", + "per_joint", + "per_body", + "assembly", + } + + def test_per_constraint_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + for item in d["per_constraint"]: + assert set(item.keys()) == { + "joint_id", + "constraint_idx", + "pebble_independent", + "jacobian_independent", + } + + def test_assembly_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + assert set(d["assembly"].keys()) == { + "classification", + "total_dof", + "redundant_count", + "is_rigid", + "is_minimally_rigid", + "has_degeneracy", + } + + def test_json_serializable(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, 
JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + # Should not raise + serialized = json.dumps(d) + assert isinstance(serialized, str) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestLabelAssemblyEdgeCases: + """Edge cases for label_assembly.""" + + def test_no_joints(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + labels = label_assembly(bodies, [], ground_body=0) + assert len(labels.per_constraint) == 0 + assert len(labels.per_joint) == 0 + assert labels.assembly.classification == "underconstrained" + # Non-ground body should be fully free + b1 = next(b for b in labels.per_body if b.body_id == 1) + assert b1.translational_dof == 3 + assert b1.rotational_dof == 3 + + def test_no_joints_floating(self) -> None: + bodies = _make_bodies((0, 0, 0)) + labels = label_assembly(bodies, [], ground_body=None) + assert len(labels.per_body) == 1 + assert labels.per_body[0].translational_dof == 3 + assert labels.per_body[0].rotational_dof == 3 + + def test_analysis_embedded(self) -> None: + """AssemblyLabels.analysis should be a valid ConstraintAnalysis.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + analysis = labels.analysis + assert hasattr(analysis, "combinatorial_classification") + assert hasattr(analysis, "jacobian_rank") + assert hasattr(analysis, "is_rigid") diff --git a/tests/datagen/test_pebble_game.py b/tests/datagen/test_pebble_game.py new file mode 100644 index 0000000..da71eab --- /dev/null +++ b/tests/datagen/test_pebble_game.py @@ -0,0 +1,206 @@ +"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game.""" + +from __future__ import annotations + +import itertools + +import numpy as np +import pytest + +from solver.datagen.pebble_game 
import PebbleGame3D +from solver.datagen.types import Joint, JointType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint: + """Shorthand for a revolute joint between bodies *a* and *b*.""" + if axis is None: + axis = np.array([0.0, 0.0, 1.0]) + return Joint( + joint_id=jid, + body_a=a, + body_b=b, + joint_type=JointType.REVOLUTE, + axis=axis, + ) + + +def _fixed(jid: int, a: int, b: int) -> Joint: + return Joint(joint_id=jid, body_a=a, body_b=b, joint_type=JointType.FIXED) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestAddBody: + """Body registration basics.""" + + def test_single_body_six_pebbles(self) -> None: + pg = PebbleGame3D() + pg.add_body(0) + assert pg.state.free_pebbles[0] == 6 + + def test_duplicate_body_no_op(self) -> None: + pg = PebbleGame3D() + pg.add_body(0) + pg.add_body(0) + assert pg.state.free_pebbles[0] == 6 + + def test_multiple_bodies(self) -> None: + pg = PebbleGame3D() + for i in range(5): + pg.add_body(i) + assert pg.get_dof() == 30 # 5 * 6 + + +class TestAddJoint: + """Joint insertion and DOF accounting.""" + + def test_revolute_removes_five_dof(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + assert len(results) == 5 # 5 scalar constraints + assert all(r["independent"] for r in results) + # 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles + assert pg.get_dof() == 7 + + def test_fixed_removes_six_dof(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_fixed(0, 0, 1)) + assert len(results) == 6 + assert all(r["independent"] for r in results) + assert pg.get_dof() == 6 + + def test_ball_removes_three_dof(self) -> None: + pg = PebbleGame3D() + j = 
Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL) + results = pg.add_joint(j) + assert len(results) == 3 + assert all(r["independent"] for r in results) + assert pg.get_dof() == 9 + + +class TestTwoBodiesRevolute: + """Two bodies connected by a revolute -- demo scenario 1.""" + + def test_internal_dof(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + # Total DOF = 7, internal = 7 - 6 = 1 + assert pg.get_internal_dof() == 1 + + def test_not_rigid(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + assert not pg.is_rigid() + + def test_classification(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + assert pg.classify_assembly() == "underconstrained" + + +class TestTwoBodiesFixed: + """Two bodies + fixed joint -- demo scenario 2.""" + + def test_zero_internal_dof(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.get_internal_dof() == 0 + + def test_rigid(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.is_rigid() + + def test_well_constrained(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.classify_assembly() == "well-constrained" + + +class TestTriangleRevolute: + """Triangle of 3 bodies with revolute joints -- demo scenario 3.""" + + @pytest.fixture() + def pg(self) -> PebbleGame3D: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + pg.add_joint(_revolute(1, 1, 2)) + pg.add_joint(_revolute(2, 2, 0)) + return pg + + def test_has_redundant_edges(self, pg: PebbleGame3D) -> None: + assert pg.get_redundant_count() > 0 + + def test_classification_overconstrained(self, pg: PebbleGame3D) -> None: + # 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed) + assert pg.classify_assembly() in ("overconstrained", "mixed") + + def test_rigid(self, pg: PebbleGame3D) -> None: + assert pg.is_rigid() + + +class TestChainNotRigid: + """A serial chain of 4 bodies with revolute joints is never 
rigid.""" + + def test_chain_underconstrained(self) -> None: + pg = PebbleGame3D() + for i in range(3): + pg.add_joint(_revolute(i, i, i + 1)) + assert not pg.is_rigid() + assert pg.classify_assembly() == "underconstrained" + + def test_chain_internal_dof(self) -> None: + pg = PebbleGame3D() + for i in range(3): + pg.add_joint(_revolute(i, i, i + 1)) + # 4 bodies * 6 = 24, minus 15 independent = 9 free, internal = 3 + assert pg.get_internal_dof() == 3 + + +class TestEdgeResults: + """Result dicts returned by add_joint.""" + + def test_result_keys(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"} + for r in results: + assert set(r.keys()) == expected_keys + + def test_edge_ids_sequential(self) -> None: + pg = PebbleGame3D() + r1 = pg.add_joint(_revolute(0, 0, 1)) + r2 = pg.add_joint(_revolute(1, 1, 2)) + all_ids = [r["edge_id"] for r in r1 + r2] + assert all_ids == list(range(10)) + + def test_dof_remaining_monotonic(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + dofs = [r["dof_remaining"] for r in results] + # Should be non-increasing (each independent edge removes a pebble) + for a, b in itertools.pairwise(dofs): + assert a >= b + + +class TestGroundedClassification: + """classify_assembly with grounded=True.""" + + def test_grounded_baseline_zero(self) -> None: + """With grounded=True the baseline is 0 (not 6).""" + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + # Ungrounded: well-constrained (6 pebbles = baseline 6) + assert pg.classify_assembly(grounded=False) == "well-constrained" + # Grounded: the 6 remaining pebbles on body 1 exceed baseline 0, + # so the raw pebble game (without a virtual ground body) sees this + # as underconstrained. The analysis function handles this properly + # by adding a virtual ground body. 
+ assert pg.classify_assembly(grounded=True) == "underconstrained" diff --git a/tests/datagen/test_types.py b/tests/datagen/test_types.py new file mode 100644 index 0000000..7667243 --- /dev/null +++ b/tests/datagen/test_types.py @@ -0,0 +1,163 @@ +"""Tests for solver.datagen.types -- shared data types.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np +import pytest + +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + PebbleState, + RigidBody, +) + + +class TestJointType: + """JointType enum construction and DOF values.""" + + EXPECTED_DOF: ClassVar[dict[str, int]] = { + "FIXED": 6, + "REVOLUTE": 5, + "CYLINDRICAL": 4, + "SLIDER": 5, + "BALL": 3, + "PLANAR": 3, + "SCREW": 5, + "UNIVERSAL": 4, + "PARALLEL": 3, + "PERPENDICULAR": 1, + "DISTANCE": 1, + } + + def test_member_count(self) -> None: + assert len(JointType) == 11 + + @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items()) + def test_dof_values(self, name: str, dof: int) -> None: + assert JointType[name].dof == dof + + def test_access_by_name(self) -> None: + assert JointType["REVOLUTE"] is JointType.REVOLUTE + + def test_value_is_tuple(self) -> None: + assert JointType.REVOLUTE.value == (1, 5) + assert JointType.REVOLUTE.dof == 5 + + +class TestRigidBody: + """RigidBody dataclass defaults and construction.""" + + def test_defaults(self) -> None: + body = RigidBody(body_id=0) + np.testing.assert_array_equal(body.position, np.zeros(3)) + np.testing.assert_array_equal(body.orientation, np.eye(3)) + assert body.local_anchors == {} + + def test_custom_position(self) -> None: + pos = np.array([1.0, 2.0, 3.0]) + body = RigidBody(body_id=7, position=pos) + np.testing.assert_array_equal(body.position, pos) + assert body.body_id == 7 + + def test_local_anchors_mutable(self) -> None: + body = RigidBody(body_id=0) + body.local_anchors["top"] = np.array([0.0, 0.0, 1.0]) + assert "top" in body.local_anchors + + def 
test_default_factory_isolation(self) -> None: + """Each instance gets its own default containers.""" + b1 = RigidBody(body_id=0) + b2 = RigidBody(body_id=1) + b1.local_anchors["x"] = np.zeros(3) + assert "x" not in b2.local_anchors + + +class TestJoint: + """Joint dataclass defaults and construction.""" + + def test_defaults(self) -> None: + j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE) + np.testing.assert_array_equal(j.anchor_a, np.zeros(3)) + np.testing.assert_array_equal(j.anchor_b, np.zeros(3)) + np.testing.assert_array_equal(j.axis, np.array([0.0, 0.0, 1.0])) + assert j.pitch == 0.0 + + def test_full_construction(self) -> None: + j = Joint( + joint_id=5, + body_a=2, + body_b=3, + joint_type=JointType.SCREW, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([2.0, 0.0, 0.0]), + axis=np.array([1.0, 0.0, 0.0]), + pitch=0.5, + ) + assert j.joint_id == 5 + assert j.joint_type is JointType.SCREW + assert j.pitch == 0.5 + + +class TestPebbleState: + """PebbleState dataclass defaults.""" + + def test_defaults(self) -> None: + s = PebbleState() + assert s.free_pebbles == {} + assert s.directed_edges == {} + assert s.independent_edges == set() + assert s.redundant_edges == set() + assert s.incoming == {} + assert s.outgoing == {} + + def test_default_factory_isolation(self) -> None: + s1 = PebbleState() + s2 = PebbleState() + s1.free_pebbles[0] = 6 + assert 0 not in s2.free_pebbles + + +class TestConstraintAnalysis: + """ConstraintAnalysis dataclass construction.""" + + def test_construction(self) -> None: + ca = ConstraintAnalysis( + combinatorial_dof=6, + combinatorial_internal_dof=0, + combinatorial_redundant=0, + combinatorial_classification="well-constrained", + per_edge_results=[], + jacobian_rank=6, + jacobian_nullity=0, + jacobian_internal_dof=0, + numerically_dependent=[], + geometric_degeneracies=0, + is_rigid=True, + is_minimally_rigid=True, + ) + assert ca.is_rigid is True + assert ca.is_minimally_rigid is True + assert 
ca.combinatorial_classification == "well-constrained" + + def test_per_edge_results_typing(self) -> None: + """per_edge_results accepts list[dict[str, Any]].""" + ca = ConstraintAnalysis( + combinatorial_dof=7, + combinatorial_internal_dof=1, + combinatorial_redundant=0, + combinatorial_classification="underconstrained", + per_edge_results=[{"edge_id": 0, "independent": True}], + jacobian_rank=5, + jacobian_nullity=1, + jacobian_internal_dof=1, + numerically_dependent=[], + geometric_degeneracies=0, + is_rigid=False, + is_minimally_rigid=False, + ) + assert len(ca.per_edge_results) == 1 + assert ca.per_edge_results[0]["edge_id"] == 0