From f61d0054000ccbc4ab7176e9519212ec6e3b8a14 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:09:37 -0600 Subject: [PATCH 01/12] first commit --- README.md | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 README.md diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 From 363b49281b91d2ef78e34be0b79861b781696c08 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:26:38 -0600 Subject: [PATCH 02/12] build: phase 0 infrastructure setup - Project structure: solver/, freecad/, export/, configs/, scripts/, tests/, docs/ - pyproject.toml with dependency groups: core, train, freecad, dev - Hydra configs: dataset (synthetic, fusion360), model (baseline, gat), training (pretrain, finetune), export (production) - Dockerfile with CUDA+PyG GPU and CPU-only targets - docker-compose.yml for train, test, data-gen services - Makefile with targets: train, test, lint, format, type-check, data-gen, export, check - Pre-commit hooks: ruff, mypy, conventional commits - Gitea Actions CI: lint, type-check, test on push/PR - README with setup and usage instructions --- .gitea/workflows/ci.yaml | 65 +++++++++++++++++++++++ .gitignore | 48 +++++++++++++++++ .pre-commit-config.yaml | 23 ++++++++ Dockerfile | 61 +++++++++++++++++++++ Makefile | 48 +++++++++++++++++ README.md | 71 +++++++++++++++++++++++++ configs/dataset/fusion360.yaml | 12 +++++ configs/dataset/synthetic.yaml | 24 +++++++++ configs/export/production.yaml | 25 +++++++++ configs/model/baseline.yaml | 24 +++++++++ configs/model/gat.yaml | 28 ++++++++++ configs/training/finetune.yaml | 45 ++++++++++++++++ configs/training/pretrain.yaml | 42 +++++++++++++++ data/fusion360/.gitkeep | 0 data/processed/.gitkeep | 0 data/splits/.gitkeep | 0 data/synthetic/.gitkeep | 0 docker-compose.yml | 39 ++++++++++++++ docs/.gitkeep | 0 export/.gitkeep | 0 freecad/__init__.py | 0 freecad/bridge/__init__.py | 0 freecad/tests/__init__.py | 0 freecad/workbench/__init__.py | 0 pyproject.toml | 97 ++++++++++++++++++++++++++++++++++ solver/__init__.py | 0 solver/datagen/__init__.py | 0 solver/datasets/__init__.py | 0 solver/evaluation/__init__.py | 0 solver/inference/__init__.py | 0 solver/models/__init__.py | 0 solver/training/__init__.py | 0 tests/__init__.py | 0 33 files changed, 652 insertions(+) create mode 100644 .gitea/workflows/ci.yaml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 Dockerfile create mode 100644 Makefile create mode 100644 configs/dataset/fusion360.yaml create mode 100644 configs/dataset/synthetic.yaml create mode 100644 configs/export/production.yaml create mode 100644 configs/model/baseline.yaml create mode 100644 configs/model/gat.yaml create mode 100644 configs/training/finetune.yaml create mode 100644 configs/training/pretrain.yaml create mode 100644 data/fusion360/.gitkeep create mode 100644 data/processed/.gitkeep create mode 100644 data/splits/.gitkeep create mode 100644 data/synthetic/.gitkeep create mode 100644 docker-compose.yml create mode 100644 docs/.gitkeep create mode 100644 export/.gitkeep create mode 100644 freecad/__init__.py create mode 100644 freecad/bridge/__init__.py create mode 100644 freecad/tests/__init__.py create mode 100644 freecad/workbench/__init__.py create mode 100644 pyproject.toml create mode 100644 solver/__init__.py create mode 100644 solver/datagen/__init__.py create mode 100644 solver/datasets/__init__.py create mode 100644 solver/evaluation/__init__.py create mode 100644 solver/inference/__init__.py create mode 100644 solver/models/__init__.py create mode 100644 solver/training/__init__.py create mode 100644 tests/__init__.py diff --git a/.gitea/workflows/ci.yaml b/.gitea/workflows/ci.yaml new file mode 100644 index 0000000..2dc42d0 --- /dev/null +++ b/.gitea/workflows/ci.yaml @@ -0,0 +1,65 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install ruff mypy + pip install -e ".[dev]" || pip install ruff mypy numpy + + - name: Ruff check + run: ruff check solver/ freecad/ tests/ scripts/ + + - name: Ruff format check + run: ruff format --check solver/ freecad/ tests/ scripts/ + + type-check: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install mypy numpy + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install torch-geometric + pip install -e ".[dev]" + + - name: Mypy + run: mypy solver/ freecad/ + + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + pip install torch --index-url https://download.pytorch.org/whl/cpu + pip install torch-geometric + pip install -e ".[train,dev]" + + - name: Run tests + run: pytest tests/ freecad/tests/ -v --tb=short diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..67e9723 --- /dev/null +++ b/.gitignore @@ -0,0 +1,48 @@ +# Python +__pycache__/ +*.py[cod] +*$py.class +*.egg-info/ +dist/ +build/ +*.egg + +# Virtual environments +.venv/ +venv/ + +# IDE +.vscode/ +.idea/ +*.swp +*.swo +*~ + +# mypy / ruff / pytest +.mypy_cache/ +.ruff_cache/ +.pytest_cache/ + +# Data (large files tracked separately) +data/synthetic/*.pt +data/fusion360/*.json +data/fusion360/*.step +data/processed/*.pt +!data/**/.gitkeep + +# Model checkpoints +*.ckpt +*.pth +*.onnx +*.torchscript + +# Experiment tracking +wandb/ +runs/ + +# OS +.DS_Store +Thumbs.db + +# Environment +.env diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..69b8611 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,23 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.3.4 + hooks: + - id: ruff + args: [--fix] + - id: ruff-format + + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.8.0 + hooks: + - id: mypy + additional_dependencies: + - torch>=2.2 + - numpy>=1.26 + args: [--ignore-missing-imports] + + - repo: https://github.com/compilerla/conventional-pre-commit + rev: v3.1.0 + hooks: + - id: conventional-pre-commit + stages: [commit-msg] + args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert] diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ebbca0a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,61 @@ +FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base + +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 + +# System deps +RUN apt-get update && apt-get install -y --no-install-recommends \ + python3.11 python3.11-venv python3.11-dev python3-pip \ + git wget curl \ + # FreeCAD headless deps + freecad \ + libgl1-mesa-glx libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1 + +# Create venv +RUN python -m venv /opt/venv +ENV PATH="/opt/venv/bin:$PATH" + +# Install PyTorch with CUDA +RUN pip install --no-cache-dir \ + torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124 + +# Install PyG +RUN pip install --no-cache-dir \ + torch-geometric \ + pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \ + -f https://data.pyg.org/whl/torch-2.4.0+cu124.html + +WORKDIR /workspace + +# Install project +COPY pyproject.toml . +RUN pip install --no-cache-dir -e ".[train,dev]" || true + +COPY . . +RUN pip install --no-cache-dir -e ".[train,dev]" + +# ------------------------------------------------------------------- +FROM base AS cpu + +# CPU-only variant (for CI and non-GPU environments) +FROM python:3.11-slim AS cpu-only + +ENV PYTHONUNBUFFERED=1 + +RUN apt-get update && apt-get install -y --no-install-recommends \ + git freecad libgl1-mesa-glx libglib2.0-0 \ + && rm -rf /var/lib/apt/lists/* + +WORKDIR /workspace + +COPY pyproject.toml . +RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu +RUN pip install --no-cache-dir torch-geometric + +COPY . . +RUN pip install --no-cache-dir -e ".[train,dev]" + +CMD ["pytest", "tests/", "-v"] diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..18d1309 --- /dev/null +++ b/Makefile @@ -0,0 +1,48 @@ +.PHONY: train test lint data-gen export format type-check install dev clean help + +PYTHON ?= python +PYTEST ?= pytest +RUFF ?= ruff +MYPY ?= mypy + +help: ## Show this help + @grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \ + awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}' + +install: ## Install core dependencies + pip install -e . + +dev: ## Install all dependencies including dev tools + pip install -e ".[train,dev]" + pre-commit install + pre-commit install --hook-type commit-msg + +train: ## Run training (pass CONFIG=path/to/config.yaml) + $(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG)) + +test: ## Run test suite + $(PYTEST) tests/ freecad/tests/ -v --tb=short + +lint: ## Run ruff linter + $(RUFF) check solver/ freecad/ tests/ scripts/ + +format: ## Format code with ruff + $(RUFF) format solver/ freecad/ tests/ scripts/ + $(RUFF) check --fix solver/ freecad/ tests/ scripts/ + +type-check: ## Run mypy type checker + $(MYPY) solver/ freecad/ + +data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml) + $(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG)) + +export: ## Export trained model for deployment + $(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL)) + +clean: ## Remove build artifacts and caches + rm -rf build/ dist/ *.egg-info/ + rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/ + find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true + find . -type f -name "*.pyc" -delete 2>/dev/null || true + +check: lint type-check test ## Run all checks (lint, type-check, test) diff --git a/README.md b/README.md index e69de29..761eb0b 100644 --- a/README.md +++ b/README.md @@ -0,0 +1,71 @@ +# kindred-solver + +Assembly constraint prediction via GNN. Produces a trained model embedded in a FreeCAD workbench (Kindred Create library), later integrated into vanilla Create. + +## Overview + +`kindred-solver` predicts whether assembly constraints (joints) are independent or redundant using graph neural networks. Given an assembly graph where bodies are nodes and joints are edges, the model classifies each constraint and reports degrees of freedom per body. + +## Repository Structure + +``` +kindred-solver/ +├── solver/ # Core library +│ ├── datagen/ # Synthetic data generation (pebble game) +│ ├── datasets/ # PyG dataset adapters +│ ├── models/ # GNN architectures (GIN, GAT, NNConv) +│ ├── training/ # Training loops and configs +│ ├── evaluation/ # Metrics and visualization +│ └── inference/ # Runtime prediction API +├── freecad/ # FreeCAD integration +│ ├── workbench/ # FreeCAD workbench addon +│ ├── bridge/ # FreeCAD <-> solver interface +│ └── tests/ # Integration tests +├── export/ # Model packaging for Create +├── configs/ # Hydra configs (dataset, model, training, export) +├── scripts/ # CLI utilities +├── data/ # Datasets (not committed) +├── tests/ # Unit and integration tests +└── docs/ # Documentation +``` + +## Setup + +### Install (development) + +```bash +pip install -e ".[train,dev]" +pre-commit install +pre-commit install --hook-type commit-msg +``` + +### Using Make + +```bash +make help # show all targets +make dev # install all deps + pre-commit hooks +make test # run tests +make lint # run ruff linter +make type-check # run mypy +make check # lint + type-check + test +make train # run training +make data-gen # generate synthetic data +make export # export model +``` + +### Using Docker + +```bash +# GPU training +docker compose up train + +# Run tests (CPU) +docker compose up test + +# Generate data +docker compose up data-gen +``` + +## License + +Apache 2.0 diff --git a/configs/dataset/fusion360.yaml b/configs/dataset/fusion360.yaml new file mode 100644 index 0000000..68b7b21 --- /dev/null +++ b/configs/dataset/fusion360.yaml @@ -0,0 +1,12 @@ +# Fusion 360 Gallery dataset config +name: fusion360 +data_dir: data/fusion360 +output_dir: data/processed + +splits: + train: 0.8 + val: 0.1 + test: 0.1 + +stratify_by: complexity +seed: 42 diff --git a/configs/dataset/synthetic.yaml b/configs/dataset/synthetic.yaml new file mode 100644 index 0000000..8450ad8 --- /dev/null +++ b/configs/dataset/synthetic.yaml @@ -0,0 +1,24 @@ +# Synthetic dataset generation config +name: synthetic +num_assemblies: 100000 +output_dir: data/synthetic + +complexity_distribution: + simple: 0.4 # 2-5 bodies + medium: 0.4 # 6-15 bodies + complex: 0.2 # 16-50 bodies + +body_count: + min: 2 + max: 50 + +templates: + - chain + - tree + - loop + - star + - mixed + +grounded_ratio: 0.5 +seed: 42 +num_workers: 4 diff --git a/configs/export/production.yaml b/configs/export/production.yaml new file mode 100644 index 0000000..1be79ec --- /dev/null +++ b/configs/export/production.yaml @@ -0,0 +1,25 @@ +# Production model export config +model_checkpoint: checkpoints/finetune/best_val_loss.ckpt +output_dir: export/ + +formats: + onnx: + enabled: true + opset_version: 17 + dynamic_axes: true + torchscript: + enabled: true + +model_card: + version: "0.1.0" + architecture: baseline + training_data: + - synthetic_100k + - fusion360_gallery + +size_budget_mb: 50 + +inference: + device: cpu + batch_size: 1 + confidence_threshold: 0.8 diff --git a/configs/model/baseline.yaml b/configs/model/baseline.yaml new file mode 100644 index 0000000..ebefe9b --- /dev/null +++ b/configs/model/baseline.yaml @@ -0,0 +1,24 @@ +# Baseline GIN model config +name: baseline +architecture: gin + +encoder: + num_layers: 3 + hidden_dim: 128 + dropout: 0.1 + +node_features_dim: 22 +edge_features_dim: 22 + +heads: + edge_classification: + enabled: true + hidden_dim: 64 + graph_classification: + enabled: true + num_classes: 4 # rigid, under, over, mixed + joint_type: + enabled: true + num_classes: 12 + dof_regression: + enabled: true diff --git a/configs/model/gat.yaml b/configs/model/gat.yaml new file mode 100644 index 0000000..6592bc1 --- /dev/null +++ b/configs/model/gat.yaml @@ -0,0 +1,28 @@ +# Advanced GAT model config +name: gat_solver +architecture: gat + +encoder: + num_layers: 4 + hidden_dim: 256 + num_heads: 8 + dropout: 0.1 + residual: true + +node_features_dim: 22 +edge_features_dim: 22 + +heads: + edge_classification: + enabled: true + hidden_dim: 128 + graph_classification: + enabled: true + num_classes: 4 + joint_type: + enabled: true + num_classes: 12 + dof_regression: + enabled: true + dof_tracking: + enabled: true diff --git a/configs/training/finetune.yaml b/configs/training/finetune.yaml new file mode 100644 index 0000000..09e8bae --- /dev/null +++ b/configs/training/finetune.yaml @@ -0,0 +1,45 @@ +# Fine-tuning on real data config +phase: finetune + +dataset: fusion360 +model: baseline + +pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt + +optimizer: + name: adamw + lr: 1e-5 + weight_decay: 1e-4 + +scheduler: + name: cosine_annealing + T_max: 50 + eta_min: 1e-7 + +training: + epochs: 50 + batch_size: 32 + gradient_clip: 1.0 + early_stopping_patience: 10 + amp: true + freeze_encoder: false # set true for frozen encoder experiment + +loss: + edge_weight: 1.0 + graph_weight: 0.5 + joint_type_weight: 0.3 + dof_weight: 0.2 + redundant_penalty: 2.0 + +checkpointing: + save_best_val_loss: true + save_best_val_accuracy: true + save_every_n_epochs: 5 + checkpoint_dir: checkpoints/finetune + +logging: + backend: wandb + project: kindred-solver + log_every_n_steps: 20 + +seed: 42 diff --git a/configs/training/pretrain.yaml b/configs/training/pretrain.yaml new file mode 100644 index 0000000..90cd232 --- /dev/null +++ b/configs/training/pretrain.yaml @@ -0,0 +1,42 @@ +# Synthetic pre-training config +phase: pretrain + +dataset: synthetic +model: baseline + +optimizer: + name: adamw + lr: 1e-3 + weight_decay: 1e-4 + +scheduler: + name: cosine_annealing + T_max: 100 + eta_min: 1e-6 + +training: + epochs: 100 + batch_size: 64 + gradient_clip: 1.0 + early_stopping_patience: 10 + amp: true + +loss: + edge_weight: 1.0 + graph_weight: 0.5 + joint_type_weight: 0.3 + dof_weight: 0.2 + redundant_penalty: 2.0 # safety loss multiplier + +checkpointing: + save_best_val_loss: true + save_best_val_accuracy: true + save_every_n_epochs: 10 + checkpoint_dir: checkpoints/pretrain + +logging: + backend: wandb # or tensorboard + project: kindred-solver + log_every_n_steps: 50 + +seed: 42 diff --git a/data/fusion360/.gitkeep b/data/fusion360/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/processed/.gitkeep b/data/processed/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/splits/.gitkeep b/data/splits/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/data/synthetic/.gitkeep b/data/synthetic/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..dfe8b4c --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,39 @@ +services: + train: + build: + context: . + dockerfile: Dockerfile + target: base + volumes: + - .:/workspace + - ./data:/workspace/data + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: all + capabilities: [gpu] + command: make train + environment: + - CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0} + - WANDB_API_KEY=${WANDB_API_KEY:-} + + test: + build: + context: . + dockerfile: Dockerfile + target: cpu-only + volumes: + - .:/workspace + command: make check + + data-gen: + build: + context: . + dockerfile: Dockerfile + target: base + volumes: + - .:/workspace + - ./data:/workspace/data + command: make data-gen diff --git a/docs/.gitkeep b/docs/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/export/.gitkeep b/export/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/freecad/__init__.py b/freecad/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/bridge/__init__.py b/freecad/bridge/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/tests/__init__.py b/freecad/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/freecad/workbench/__init__.py b/freecad/workbench/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..439159d --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,97 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[project] +name = "kindred-solver" +version = "0.1.0" +description = "Assembly constraint prediction via GNN for Kindred Create" +readme = "README.md" +license = "Apache-2.0" +requires-python = ">=3.11" +authors = [ + { name = "Kindred Systems" }, +] +classifiers = [ + "Development Status :: 3 - Alpha", + "License :: OSI Approved :: Apache Software License", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Topic :: Scientific/Engineering", +] +dependencies = [ + "torch>=2.2", + "torch-geometric>=2.5", + "numpy>=1.26", + "scipy>=1.12", +] + +[project.optional-dependencies] +train = [ + "wandb>=0.16", + "tensorboard>=2.16", + "hydra-core>=1.3", + "omegaconf>=2.3", + "matplotlib>=3.8", + "networkx>=3.2", +] +freecad = [ + "pyside6>=6.6", +] +dev = [ + "pytest>=8.0", + "pytest-cov>=4.1", + "ruff>=0.3", + "mypy>=1.8", + "pre-commit>=3.6", +] + +[project.urls] +Repository = "https://git.kindred-systems.com/kindred/solver" + +[tool.hatch.build.targets.wheel] +packages = ["solver", "freecad"] + +[tool.ruff] +target-version = "py311" +line-length = 100 + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "W", # pycodestyle warnings + "F", # pyflakes + "I", # isort + "N", # pep8-naming + "UP", # pyupgrade + "B", # flake8-bugbear + "SIM", # flake8-simplify + "TCH", # flake8-type-checking + "RUF", # ruff-specific +] + +[tool.ruff.lint.isort] +known-first-party = ["solver", "freecad"] + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_configs = true +disallow_untyped_defs = true +check_untyped_defs = true + +[[tool.mypy.overrides]] +module = [ + "torch.*", + "torch_geometric.*", + "scipy.*", + "wandb.*", + "hydra.*", + "omegaconf.*", +] +ignore_missing_imports = true + +[tool.pytest.ini_options] +testpaths = ["tests", "freecad/tests"] +addopts = "-v --tb=short" diff --git a/solver/__init__.py b/solver/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/datasets/__init__.py b/solver/datasets/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/evaluation/__init__.py b/solver/evaluation/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/inference/__init__.py b/solver/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/models/__init__.py b/solver/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/solver/training/__init__.py b/solver/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 From 1b6135129ed169b9acd4b7115aa56245923e8ab6 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:43:19 -0600 Subject: [PATCH 03/12] feat: port shared types to solver/datagen/types.py Port JointType, RigidBody, Joint, PebbleState, and ConstraintAnalysis from data/synthetic/pebble-game.py into the solver package. - Add __all__ export list - Put typing.Any behind TYPE_CHECKING (ruff TCH003) - Parameterize list[dict] as list[dict[str, Any]] (mypy strict) - Re-export all types from solver.datagen.__init__ Closes #1 --- solver/datagen/__init__.py | 17 +++++ solver/datagen/types.py | 137 +++++++++++++++++++++++++++++++++++++ 2 files changed, 154 insertions(+) create mode 100644 solver/datagen/types.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index e69de29..6e53b70 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -0,0 +1,17 @@ +"""Data generation utilities for assembly constraint training data.""" + +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + PebbleState, + RigidBody, +) + +__all__ = [ + "ConstraintAnalysis", + "Joint", + "JointType", + "PebbleState", + "RigidBody", +] diff --git a/solver/datagen/types.py b/solver/datagen/types.py new file mode 100644 index 0000000..754e5dc --- /dev/null +++ b/solver/datagen/types.py @@ -0,0 +1,137 @@ +"""Shared data types for assembly constraint analysis. + +Types ported from the pebble-game synthetic data generator for reuse +across the solver package (data generation, training, inference). +""" + +from __future__ import annotations + +import enum +from dataclasses import dataclass, field +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "ConstraintAnalysis", + "Joint", + "JointType", + "PebbleState", + "RigidBody", +] + + +# --------------------------------------------------------------------------- +# Joint definitions: each joint type removes a known number of DOF +# --------------------------------------------------------------------------- + + +class JointType(enum.Enum): + """Standard CAD joint types with their DOF-removal counts. + + Each joint between two 6-DOF rigid bodies removes a specific number + of relative degrees of freedom. In the body-bar-hinge multigraph + representation, each joint maps to a number of edges equal to the + DOF it removes. + + DOF removed = number of scalar constraint equations the joint imposes. + """ + + FIXED = 6 # Locks all relative motion + REVOLUTE = 5 # Allows rotation about one axis + CYLINDRICAL = 4 # Allows rotation + translation along one axis + SLIDER = 5 # Allows translation along one axis (prismatic) + BALL = 3 # Allows rotation about a point (spherical) + PLANAR = 3 # Allows 2D translation + rotation normal to plane + SCREW = 5 # Coupled rotation-translation (helical) + UNIVERSAL = 4 # Two rotational DOF (Cardan/U-joint) + PARALLEL = 3 # Forces parallel orientation (3 rotation constraints) + PERPENDICULAR = 1 # Single angular constraint + DISTANCE = 1 # Single scalar distance constraint + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- + + +@dataclass +class RigidBody: + """A rigid body in the assembly with pose and geometry info.""" + + body_id: int + position: np.ndarray = field(default_factory=lambda: np.zeros(3)) + orientation: np.ndarray = field(default_factory=lambda: np.eye(3)) + + # Anchor points for joints, in local frame + # Populated when joints reference specific geometry + local_anchors: dict[str, np.ndarray] = field(default_factory=dict) + + +@dataclass +class Joint: + """A joint connecting two rigid bodies.""" + + joint_id: int + body_a: int # Index of first body + body_b: int # Index of second body + joint_type: JointType + + # Joint parameters in world frame + anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3)) + anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3)) + axis: np.ndarray = field( + default_factory=lambda: np.array([0.0, 0.0, 1.0]), + ) + + # For screw joints + pitch: float = 0.0 + + +@dataclass +class PebbleState: + """Tracks the state of the pebble game on the multigraph.""" + + # Number of free pebbles per body (vertex). Starts at 6. + free_pebbles: dict[int, int] = field(default_factory=dict) + + # Directed edges: edge_id -> (source_body, target_body) + # Edge is directed away from the body that "spent" a pebble. + directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict) + + # Track which edges are independent vs redundant + independent_edges: set[int] = field(default_factory=set) + redundant_edges: set[int] = field(default_factory=set) + + # Adjacency: body_id -> set of (edge_id, neighbor_body_id) + # Following directed edges *towards* a body (incoming edges) + incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + # Outgoing edges from a body + outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict) + + +@dataclass +class ConstraintAnalysis: + """Results of analyzing an assembly's constraint system.""" + + # Pebble game (combinatorial) results + combinatorial_dof: int + combinatorial_internal_dof: int + combinatorial_redundant: int + combinatorial_classification: str + per_edge_results: list[dict[str, Any]] + + # Numerical (Jacobian) results + jacobian_rank: int + jacobian_nullity: int # = 6n - rank = total DOF + jacobian_internal_dof: int # = nullity - 6 + numerically_dependent: list[int] + + # Combined + geometric_degeneracies: int # = combinatorial_independent - jacobian_rank + is_rigid: bool + is_minimally_rigid: bool From 35d4ef736f27c826c9f17abc5f9513c932cfb1eb Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:47:36 -0600 Subject: [PATCH 04/12] feat: port PebbleGame3D to solver/datagen/pebble_game.py Port the (6,6)-pebble game implementation from data/synthetic/pebble-game.py. Imports shared types from solver.datagen.types. No behavioral changes. - Full type annotations on all methods (mypy strict) - Ruff-compliant: ternary, combined if, unpacking - Re-exported from solver.datagen.__init__ Closes #2 --- solver/datagen/__init__.py | 2 + solver/datagen/pebble_game.py | 258 ++++++++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 solver/datagen/pebble_game.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 6e53b70..6504939 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,5 +1,6 @@ """Data generation utilities for assembly constraint training data.""" +from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( ConstraintAnalysis, Joint, @@ -12,6 +13,7 @@ __all__ = [ "ConstraintAnalysis", "Joint", "JointType", + "PebbleGame3D", "PebbleState", "RigidBody", ] diff --git a/solver/datagen/pebble_game.py b/solver/datagen/pebble_game.py new file mode 100644 index 0000000..655cd13 --- /dev/null +++ b/solver/datagen/pebble_game.py @@ -0,0 +1,258 @@ +"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis. + +Implements the pebble game algorithm adapted for CAD assembly constraint +graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints +between bodies remove DOF according to their type. + +The pebble game provides a fast combinatorial *necessary* condition for +rigidity via Tay's theorem. It does not detect geometric degeneracies — +use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient* +condition. + +References: + - Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008 + - Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity + Percolation: The Pebble Game", J. Comput. Phys., 1997 + - Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984 +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from solver.datagen.types import Joint, PebbleState + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["PebbleGame3D"] + + +class PebbleGame3D: + """Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks. + + For body-bar-hinge structures in 3D, Tay's theorem states that a + multigraph G on n vertices is generically minimally rigid iff: + |E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2) + + The (6,6)-pebble game tests this sparsity condition incrementally. + Each vertex starts with 6 pebbles (representing 6 DOF). To insert + an edge, we need to collect 6+1=7 pebbles on its two endpoints. + If we can, the edge is independent (removes a DOF). If not, it's + redundant (overconstrained). + + In the CAD assembly context: + - Vertices = rigid bodies + - Edges = scalar constraints from joints + - A revolute joint (5 DOF removed) maps to 5 multigraph edges + - A fixed joint (6 DOF removed) maps to 6 multigraph edges + """ + + K = 6 # Pebbles per vertex (DOF per rigid body in 3D) + L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge + + def __init__(self) -> None: + self.state = PebbleState() + self._edge_counter = 0 + self._bodies: set[int] = set() + + def add_body(self, body_id: int) -> None: + """Register a rigid body (vertex) with K=6 free pebbles.""" + if body_id in self._bodies: + return + self._bodies.add(body_id) + self.state.free_pebbles[body_id] = self.K + self.state.incoming[body_id] = set() + self.state.outgoing[body_id] = set() + + def add_joint(self, joint: Joint) -> list[dict[str, Any]]: + """Expand a joint into multigraph edges and test each for independence. + + A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph. + Each edge is tested individually via the pebble game. + + Returns a list of dicts, one per scalar constraint, with: + - edge_id: int + - independent: bool + - dof_remaining: int (total free pebbles after this edge) + """ + self.add_body(joint.body_a) + self.add_body(joint.body_b) + + num_constraints = joint.joint_type.value + results: list[dict[str, Any]] = [] + + for i in range(num_constraints): + edge_id = self._edge_counter + self._edge_counter += 1 + + independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b) + total_free = sum(self.state.free_pebbles.values()) + + results.append( + { + "edge_id": edge_id, + "joint_id": joint.joint_id, + "constraint_index": i, + "independent": independent, + "dof_remaining": total_free, + } + ) + + return results + + def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool: + """Try to insert a directed edge between u and v. + + The edge is accepted (independent) iff we can collect L+1 = 7 + pebbles on the two endpoints {u, v} combined. + + If accepted, one pebble is consumed and the edge is directed + away from the vertex that gives up the pebble. + """ + # Count current free pebbles on u and v + available = self.state.free_pebbles[u] + self.state.free_pebbles[v] + + # Try to gather enough pebbles via DFS reachability search + if available < self.L + 1: + needed = (self.L + 1) - available + # Try to free pebbles by searching from u first, then v + for target in (u, v): + while needed > 0: + found = self._search_and_collect(target, frozenset({u, v})) + if not found: + break + needed -= 1 + + # Recheck after collection attempts + available = self.state.free_pebbles[u] + self.state.free_pebbles[v] + + if available >= self.L + 1: + # Accept: consume a pebble from whichever endpoint has one + source = u if self.state.free_pebbles[u] > 0 else v + + self.state.free_pebbles[source] -= 1 + self.state.directed_edges[edge_id] = (source, v if source == u else u) + self.state.outgoing[source].add((edge_id, v if source == u else u)) + target = v if source == u else u + self.state.incoming[target].add((edge_id, source)) + self.state.independent_edges.add(edge_id) + return True + else: + # Reject: edge is redundant (overconstrained) + self.state.redundant_edges.add(edge_id) + return False + + def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool: + """DFS to find a free pebble reachable from *target* and move it. + + Follows directed edges *backwards* (from destination to source) + to find a vertex with a free pebble that isn't in *forbidden*. + When found, reverses the path to move the pebble to *target*. + + Returns True if a pebble was successfully moved to target. + """ + # BFS/DFS through the directed graph following outgoing edges + # from target. An outgoing edge (target -> w) means target spent + # a pebble on that edge. If we can find a vertex with a free + # pebble, we reverse edges along the path to move it. + + visited: set[int] = set() + # Stack: (current_vertex, path_of_edge_ids_to_reverse) + stack: list[tuple[int, list[int]]] = [(target, [])] + + while stack: + current, path = stack.pop() + if current in visited: + continue + visited.add(current) + + # Check if current vertex (not in forbidden, not target) + # has a free pebble + if ( + current != target + and current not in forbidden + and self.state.free_pebbles[current] > 0 + ): + # Found a pebble — reverse the path + self._reverse_path(path, current) + return True + + # Follow outgoing edges from current vertex + for eid, neighbor in self.state.outgoing.get(current, set()): + if neighbor not in visited: + stack.append((neighbor, [*path, eid])) + + return False + + def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None: + """Reverse directed edges along a path, moving a pebble to the start. + + The pebble at *pebble_source* is consumed by the last edge in + the path, and a pebble is freed at the path's start vertex. + """ + if not edge_ids: + return + + # Reverse each edge in the path + for eid in edge_ids: + old_source, old_target = self.state.directed_edges[eid] + + # Remove from adjacency + self.state.outgoing[old_source].discard((eid, old_target)) + self.state.incoming[old_target].discard((eid, old_source)) + + # Reverse direction + self.state.directed_edges[eid] = (old_target, old_source) + self.state.outgoing[old_target].add((eid, old_source)) + self.state.incoming[old_source].add((eid, old_target)) + + # Move pebble counts: source loses one, first vertex in path gains one + self.state.free_pebbles[pebble_source] -= 1 + + # After all reversals, the vertex at the beginning of the + # search path gains a pebble + _first_src, first_tgt = self.state.directed_edges[edge_ids[0]] + self.state.free_pebbles[first_tgt] += 1 + + def get_dof(self) -> int: + """Total remaining DOF = sum of free pebbles. + + For a fully rigid assembly, this should be 6 (the trivial rigid + body motions of the whole assembly). Internal DOF = total - 6. + """ + return sum(self.state.free_pebbles.values()) + + def get_internal_dof(self) -> int: + """Internal (non-trivial) degrees of freedom.""" + return max(0, self.get_dof() - 6) + + def is_rigid(self) -> bool: + """Combinatorial rigidity check: rigid iff at most 6 pebbles remain.""" + return self.get_dof() <= self.L + + def get_redundant_count(self) -> int: + """Number of redundant (overconstrained) scalar constraints.""" + return len(self.state.redundant_edges) + + def classify_assembly(self, *, grounded: bool = False) -> str: + """Classify the assembly state. + + Args: + grounded: If True, the baseline trivial DOF is 0 (not 6), + because the ground body's 6 DOF were removed. + """ + total_dof = self.get_dof() + redundant = self.get_redundant_count() + baseline = 0 if grounded else self.L + + if redundant > 0 and total_dof > baseline: + return "mixed" # Both under and over-constrained regions + elif redundant > 0: + return "overconstrained" + elif total_dof > baseline: + return "underconstrained" + elif total_dof == baseline: + return "well-constrained" + else: + return "overconstrained" From 455b6318d94951fa8fb0709bb8e47be9478d3153 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:50:16 -0600 Subject: [PATCH 05/12] feat: port JacobianVerifier to solver/datagen/jacobian.py Port the constraint Jacobian builder and numerical rank verifier from data/synthetic/pebble-game.py. All 11 joint type builders, SVD rank computation, and incremental dependency detection. - Full type annotations (mypy strict) - Ruff lint and format clean - Re-exported from solver.datagen.__init__ Closes #3 --- solver/datagen/__init__.py | 2 + solver/datagen/jacobian.py | 517 +++++++++++++++++++++++++++++++++++++ 2 files changed, 519 insertions(+) create mode 100644 solver/datagen/jacobian.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 6504939..a48a988 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,5 +1,6 @@ """Data generation utilities for assembly constraint training data.""" +from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( ConstraintAnalysis, @@ -11,6 +12,7 @@ from solver.datagen.types import ( __all__ = [ "ConstraintAnalysis", + "JacobianVerifier", "Joint", "JointType", "PebbleGame3D", diff --git a/solver/datagen/jacobian.py b/solver/datagen/jacobian.py new file mode 100644 index 0000000..09bb769 --- /dev/null +++ b/solver/datagen/jacobian.py @@ -0,0 +1,517 @@ +"""Numerical Jacobian rank verification for assembly constraint analysis. + +Builds the constraint Jacobian matrix and analyzes its numerical rank +to detect geometric degeneracies that the combinatorial pebble game +cannot identify (e.g., parallel revolute axes creating hidden dependencies). + +References: + - Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D" +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.types import Joint, JointType, RigidBody + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["JacobianVerifier"] + + +class JacobianVerifier: + """Builds and analyzes the constraint Jacobian for numerical rank check. + + The pebble game gives a combinatorial *necessary* condition for + rigidity. However, geometric special cases (e.g., all revolute axes + parallel, creating a hidden dependency) require numerical verification. + + For each joint, we construct the constraint Jacobian rows that map + the 6n-dimensional generalized velocity vector to the constraint + violation rates. The rank of this Jacobian equals the number of + truly independent constraints. + + The generalized velocity vector for n bodies is:: + + v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z] + + Each scalar constraint C_i contributes one row to J such that:: + + dC_i/dt = J_i @ v = 0 + """ + + def __init__(self, bodies: list[RigidBody]) -> None: + self.bodies = {b.body_id: b for b in bodies} + self.body_index = {b.body_id: i for i, b in enumerate(bodies)} + self.n_bodies = len(bodies) + self.jacobian_rows: list[np.ndarray] = [] + self.row_labels: list[dict[str, Any]] = [] + + def _body_cols(self, body_id: int) -> tuple[int, int]: + """Return the column range [start, end) for a body in J.""" + idx = self.body_index[body_id] + return idx * 6, (idx + 1) * 6 + + def add_joint_constraints(self, joint: Joint) -> int: + """Add Jacobian rows for all scalar constraints of a joint. + + Returns the number of rows added. + """ + builder = { + JointType.FIXED: self._build_fixed, + JointType.REVOLUTE: self._build_revolute, + JointType.CYLINDRICAL: self._build_cylindrical, + JointType.SLIDER: self._build_slider, + JointType.BALL: self._build_ball, + JointType.PLANAR: self._build_planar, + JointType.DISTANCE: self._build_distance, + JointType.PARALLEL: self._build_parallel, + JointType.PERPENDICULAR: self._build_perpendicular, + JointType.UNIVERSAL: self._build_universal, + JointType.SCREW: self._build_screw, + } + + rows_before = len(self.jacobian_rows) + builder[joint.joint_type](joint) + return len(self.jacobian_rows) - rows_before + + def _make_row(self) -> np.ndarray: + """Create a zero row of width 6*n_bodies.""" + return np.zeros(6 * self.n_bodies) + + def _skew(self, v: np.ndarray) -> np.ndarray: + """Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``.""" + return np.array( + [ + [0, -v[2], v[1]], + [v[2], 0, -v[0]], + [-v[1], v[0], 0], + ] + ) + + # --- Ball-and-socket (spherical) joint: 3 translation constraints --- + + def _build_ball(self, joint: Joint) -> None: + """Ball joint: coincident point constraint. + + ``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0`` + (3 equations) + + Jacobian rows (for each of x, y, z): + body_a linear: -I + body_a angular: +skew(R_a @ r_a) + body_b linear: +I + body_b angular: -skew(R_b @ r_b) + """ + # Use anchor positions directly as world-frame offsets + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + e = np.zeros(3) + e[axis_idx] = 1.0 + + row[col_a_start : col_a_start + 3] = -e + row[col_a_start + 3 : col_a_end] = np.cross(r_a, e) + + row[col_b_start : col_b_start + 3] = e + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "ball_translation", + "axis": axis_idx, + } + ) + + # --- Fixed joint: 3 translation + 3 rotation constraints --- + + def _build_fixed(self, joint: Joint) -> None: + """Fixed joint = ball joint + locked rotation. + + Translation part: same as ball joint (3 rows). + Rotation part: relative angular velocity must be zero (3 rows). + """ + self._build_ball(joint) + + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "fixed_rotation", + "axis": axis_idx, + } + ) + + # --- Revolute (hinge) joint: 3 translation + 2 rotation constraints --- + + def _build_revolute(self, joint: Joint) -> None: + """Revolute joint: rotation only about one axis. + + Translation: same as ball (3 rows). + Rotation: relative angular velocity must be parallel to hinge axis. + """ + self._build_ball(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "revolute_rotation", + "perp_axis": i, + } + ) + + # --- Cylindrical joint: 2 translation + 2 rotation constraints --- + + def _build_cylindrical(self, joint: Joint) -> None: + """Cylindrical joint: allows rotation + translation along one axis. + + Translation: constrain motion perpendicular to axis (2 rows). + Rotation: constrain rotation perpendicular to axis (2 rows). + """ + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start : col_a_start + 3] = -t + row[col_a_start + 3 : col_a_end] = np.cross(r_a, t) + row[col_b_start : col_b_start + 3] = t + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "cylindrical_translation", + "perp_axis": i, + } + ) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "cylindrical_rotation", + "perp_axis": i, + } + ) + + # --- Slider (prismatic) joint: 2 translation + 3 rotation constraints --- + + def _build_slider(self, joint: Joint) -> None: + """Slider/prismatic joint: translation along one axis only. + + Translation: perpendicular translation constrained (2 rows). + Rotation: all relative rotation constrained (3 rows). + """ + axis = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(axis) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start : col_a_start + 3] = -t + row[col_a_start + 3 : col_a_end] = np.cross(r_a, t) + row[col_b_start : col_b_start + 3] = t + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "slider_translation", + "perp_axis": i, + } + ) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "slider_rotation", + "axis": axis_idx, + } + ) + + # --- Planar joint: 1 translation + 2 rotation constraints --- + + def _build_planar(self, joint: Joint) -> None: + """Planar joint: constrains to a plane. + + Translation: motion along plane normal constrained (1 row). + Rotation: rotation about axes in the plane constrained (2 rows). + """ + normal = joint.axis / np.linalg.norm(joint.axis) + t1, t2 = self._perpendicular_pair(normal) + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -normal + row[col_a_start + 3 : col_a_end] = np.cross(r_a, normal) + row[col_b_start : col_b_start + 3] = normal + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, normal) + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "planar_translation", + } + ) + + for i, t in enumerate((t1, t2)): + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -t + row[col_b_start + 3 : col_b_start + 6] = t + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "planar_rotation", + "perp_axis": i, + } + ) + + # --- Distance constraint: 1 scalar --- + + def _build_distance(self, joint: Joint) -> None: + """Distance constraint: ``||p_b - p_a|| = d``. + + Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0`` + where ``direction = normalized(p_b - p_a)``. + """ + p_a = joint.anchor_a + p_b = joint.anchor_b + diff = p_b - p_a + dist = np.linalg.norm(diff) + direction = np.array([1.0, 0.0, 0.0]) if dist < 1e-12 else diff / dist + + r_a = joint.anchor_a - self.bodies[joint.body_a].position + r_b = joint.anchor_b - self.bodies[joint.body_b].position + + col_a_start, col_a_end = self._body_cols(joint.body_a) + col_b_start, col_b_end = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -direction + row[col_a_start + 3 : col_a_end] = np.cross(r_a, direction) + row[col_b_start : col_b_start + 3] = direction + row[col_b_start + 3 : col_b_end] = -np.cross(r_b, direction) + + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "distance", + } + ) + + # --- Parallel constraint: 3 rotation constraints --- + + def _build_parallel(self, joint: Joint) -> None: + """Parallel: all relative rotation constrained (same as fixed rotation). + + In practice only 2 of 3 are independent for a single axis, but + we emit 3 and let the rank check sort it out. + """ + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + for axis_idx in range(3): + row = self._make_row() + row[col_a_start + 3 + axis_idx] = -1.0 + row[col_b_start + 3 + axis_idx] = 1.0 + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "parallel_rotation", + "axis": axis_idx, + } + ) + + # --- Perpendicular constraint: 1 angular --- + + def _build_perpendicular(self, joint: Joint) -> None: + """Perpendicular: single dot-product angular constraint.""" + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -axis + row[col_b_start + 3 : col_b_start + 6] = axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "perpendicular", + } + ) + + # --- Universal (Cardan) joint: 3 translation + 1 rotation --- + + def _build_universal(self, joint: Joint) -> None: + """Universal joint: ball + one rotation constraint. + + Allows rotation about two axes, constrains rotation about the third. + """ + self._build_ball(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start + 3 : col_a_start + 6] = -axis + row[col_b_start + 3 : col_b_start + 6] = axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "universal_rotation", + } + ) + + # --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled --- + + def _build_screw(self, joint: Joint) -> None: + """Screw joint: coupled rotation-translation along axis. + + Like cylindrical but with a coupling constraint: + ``v_axial - pitch * w_axial = 0`` + """ + self._build_cylindrical(joint) + + axis = joint.axis / np.linalg.norm(joint.axis) + col_a_start, _ = self._body_cols(joint.body_a) + col_b_start, _ = self._body_cols(joint.body_b) + + row = self._make_row() + row[col_a_start : col_a_start + 3] = -axis + row[col_b_start : col_b_start + 3] = axis + row[col_a_start + 3 : col_a_start + 6] = joint.pitch * axis + row[col_b_start + 3 : col_b_start + 6] = -joint.pitch * axis + self.jacobian_rows.append(row) + self.row_labels.append( + { + "joint_id": joint.joint_id, + "type": "screw_coupling", + } + ) + + # --- Utilities --- + + def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Generate two unit vectors perpendicular to *axis* and each other.""" + if abs(axis[0]) < 0.9: + t1 = np.cross(axis, np.array([1.0, 0, 0])) + else: + t1 = np.cross(axis, np.array([0, 1.0, 0])) + t1 /= np.linalg.norm(t1) + t2 = np.cross(axis, t1) + t2 /= np.linalg.norm(t2) + return t1, t2 + + def get_jacobian(self) -> np.ndarray: + """Return the full constraint Jacobian matrix.""" + if not self.jacobian_rows: + return np.zeros((0, 6 * self.n_bodies)) + return np.array(self.jacobian_rows) + + def numerical_rank(self, tol: float = 1e-8) -> int: + """Compute the numerical rank of the constraint Jacobian via SVD. + + This is the number of truly independent scalar constraints, + accounting for geometric degeneracies that the combinatorial + pebble game cannot detect. + """ + j = self.get_jacobian() + if j.size == 0: + return 0 + sv = np.linalg.svd(j, compute_uv=False) + return int(np.sum(sv > tol)) + + def find_dependencies(self, tol: float = 1e-8) -> list[int]: + """Identify which constraint rows are numerically dependent. + + Returns indices of rows that can be removed without changing + the Jacobian's rank. + """ + j = self.get_jacobian() + if j.size == 0: + return [] + + n_rows = j.shape[0] + dependent: list[int] = [] + + current = np.zeros((0, j.shape[1])) + current_rank = 0 + + for i in range(n_rows): + candidate = np.vstack([current, j[i : i + 1, :]]) if current.size else j[i : i + 1, :] + sv = np.linalg.svd(candidate, compute_uv=False) + new_rank = int(np.sum(sv > tol)) + + if new_rank > current_rank: + current = candidate + current_rank = new_rank + else: + dependent.append(i) + + return dependent From 9a31df4988b6d5267717a8c85542908c04aca4df Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:52:03 -0600 Subject: [PATCH 06/12] feat: port analyze_assembly to solver/datagen/analysis.py Port the combined pebble game + Jacobian verification entry point from data/synthetic/pebble-game.py. Ties PebbleGame3D and JacobianVerifier together with virtual ground body support. - Optional[int] -> int | None (UP007) - GROUND_ID constant extracted to module level - Full type annotations (mypy strict) - Re-exported from solver.datagen.__init__ Closes #4 --- solver/datagen/__init__.py | 2 + solver/datagen/analysis.py | 130 +++++++++++++++++++++++++++++++++++++ 2 files changed, 132 insertions(+) create mode 100644 solver/datagen/analysis.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index a48a988..7d64bc0 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,5 +1,6 @@ """Data generation utilities for assembly constraint training data.""" +from solver.datagen.analysis import analyze_assembly from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( @@ -18,4 +19,5 @@ __all__ = [ "PebbleGame3D", "PebbleState", "RigidBody", + "analyze_assembly", ] diff --git a/solver/datagen/analysis.py b/solver/datagen/analysis.py new file mode 100644 index 0000000..f862802 --- /dev/null +++ b/solver/datagen/analysis.py @@ -0,0 +1,130 @@ +"""Combined pebble game + Jacobian verification analysis. + +Provides :func:`analyze_assembly`, the main entry point for full rigidity +analysis of an assembly using both combinatorial and numerical methods. +""" + +from __future__ import annotations + +import numpy as np + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +__all__ = ["analyze_assembly"] + +_GROUND_ID = -1 + + +def analyze_assembly( + bodies: list[RigidBody], + joints: list[Joint], + ground_body: int | None = None, +) -> ConstraintAnalysis: + """Full rigidity analysis of an assembly using both methods. + + Args: + bodies: List of rigid bodies in the assembly. + joints: List of joints connecting bodies. + ground_body: If set, this body is fixed (adds 6 implicit constraints). + + Returns: + ConstraintAnalysis with combinatorial and numerical results. + """ + # --- Pebble Game --- + pg = PebbleGame3D() + all_edge_results = [] + + # Add a virtual ground body (id=-1) if grounding is requested. + # Grounding body X means adding a fixed joint between X and + # the virtual ground. This properly lets the pebble game account + # for the 6 removed DOF without breaking invariants. + if ground_body is not None: + pg.add_body(_GROUND_ID) + + for body in bodies: + pg.add_body(body.body_id) + + if ground_body is not None: + ground_joint = Joint( + joint_id=-1, + body_a=ground_body, + body_b=_GROUND_ID, + joint_type=JointType.FIXED, + anchor_a=bodies[0].position if bodies else np.zeros(3), + anchor_b=bodies[0].position if bodies else np.zeros(3), + ) + pg.add_joint(ground_joint) + # Don't include ground joint edges in the output labels + # (they're infrastructure, not user constraints) + + for joint in joints: + results = pg.add_joint(joint) + all_edge_results.extend(results) + + combinatorial_independent = len(pg.state.independent_edges) + grounded = ground_body is not None + + # The virtual ground body contributes 6 pebbles to the total. + # Subtract those from the reported DOF for user-facing numbers. + raw_dof = pg.get_dof() + ground_offset = 6 if grounded else 0 + effective_dof = raw_dof - ground_offset + effective_internal_dof = max(0, effective_dof - (0 if grounded else 6)) + + combinatorial_classification = pg.classify_assembly(grounded=grounded) + + # --- Jacobian Verification --- + verifier = JacobianVerifier(bodies) + + for joint in joints: + verifier.add_joint_constraints(joint) + + # If grounded, remove the ground body's columns (fix its DOF) + j = verifier.get_jacobian() + if ground_body is not None and j.size > 0: + idx = verifier.body_index[ground_body] + cols_to_remove = list(range(idx * 6, (idx + 1) * 6)) + j = np.delete(j, cols_to_remove, axis=1) + + if j.size > 0: + sv = np.linalg.svd(j, compute_uv=False) + jacobian_rank = int(np.sum(sv > 1e-8)) + else: + jacobian_rank = 0 + + n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies) + jacobian_nullity = n_cols - jacobian_rank + + dependent = verifier.find_dependencies() + + # Adjust for ground + trivial_dof = 0 if ground_body is not None else 6 + jacobian_internal_dof = jacobian_nullity - trivial_dof + + geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank) + + # Rigidity: numerically rigid if nullity == trivial DOF + is_rigid = jacobian_nullity <= trivial_dof + is_minimally_rigid = is_rigid and len(dependent) == 0 + + return ConstraintAnalysis( + combinatorial_dof=effective_dof, + combinatorial_internal_dof=effective_internal_dof, + combinatorial_redundant=pg.get_redundant_count(), + combinatorial_classification=combinatorial_classification, + per_edge_results=all_edge_results, + jacobian_rank=jacobian_rank, + jacobian_nullity=jacobian_nullity, + jacobian_internal_dof=max(0, jacobian_internal_dof), + numerically_dependent=dependent, + geometric_degeneracies=geometric_degeneracies, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + ) From 831a10cdb4eb1b4691414f3ece88e96154994e54 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 13:54:32 -0600 Subject: [PATCH 07/12] feat: port SyntheticAssemblyGenerator to solver/datagen/generator.py Port chain, rigid, and overconstrained assembly generators plus the training batch generation from data/synthetic/pebble-game.py. - Refactored rng.choice on enums/callables to integer indexing (mypy) - Typed n_bodies_range as tuple[int, int] - Typed batch return as list[dict[str, Any]] - Full type annotations (mypy strict) - Re-exported from solver.datagen.__init__ Closes #5 --- solver/datagen/__init__.py | 2 + solver/datagen/generator.py | 252 ++++++++++++++++++++++++++++++++++++ 2 files changed, 254 insertions(+) create mode 100644 solver/datagen/generator.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 7d64bc0..bd6fc3c 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,6 +1,7 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly +from solver.datagen.generator import SyntheticAssemblyGenerator from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( @@ -19,5 +20,6 @@ __all__ = [ "PebbleGame3D", "PebbleState", "RigidBody", + "SyntheticAssemblyGenerator", "analyze_assembly", ] diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py new file mode 100644 index 0000000..6c6bb0f --- /dev/null +++ b/solver/datagen/generator.py @@ -0,0 +1,252 @@ +"""Synthetic assembly graph generator for training data production. + +Generates assembly graphs with known constraint classifications using +the pebble game and Jacobian verification. Each assembly is fully labeled +with per-constraint independence flags and assembly-level classification. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.analysis import analyze_assembly +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["SyntheticAssemblyGenerator"] + + +class SyntheticAssemblyGenerator: + """Generates assembly graphs with known minimal constraint sets. + + Uses the pebble game to incrementally build assemblies, tracking + exactly which constraints are independent at each step. This produces + labeled training data: (assembly_graph, constraint_set, labels). + + Labels per constraint: + - independent: bool (does this constraint remove a DOF?) + - redundant: bool (is this constraint overconstrained?) + - minimal_set: bool (part of a minimal rigidity basis?) + """ + + def __init__(self, seed: int = 42) -> None: + self.rng = np.random.default_rng(seed) + + def generate_chain_assembly( + self, + n_bodies: int, + joint_type: JointType = JointType.REVOLUTE, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a serial kinematic chain. + + Simple but useful: each body connects to the next with the + specified joint type. Results in an underconstrained assembly + (serial chain is never rigid without closing loops). + """ + bodies = [] + joints = [] + + for i in range(n_bodies): + pos = np.array([i * 2.0, 0.0, 0.0]) + bodies.append(RigidBody(body_id=i, position=pos)) + + for i in range(n_bodies - 1): + axis = self.rng.standard_normal(3) + axis /= np.linalg.norm(axis) + + anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0]) + + joints.append( + Joint( + joint_id=i, + body_a=i, + body_b=i + 1, + joint_type=joint_type, + anchor_a=anchor, + anchor_b=anchor, + axis=axis, + ) + ) + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_rigid_assembly( + self, n_bodies: int + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a minimally rigid assembly by adding joints until rigid. + + Strategy: start with fixed joints on a spanning tree (guarantees + rigidity), then randomly relax some to weaker joint types while + maintaining rigidity via the pebble game check. + """ + bodies = [] + for i in range(n_bodies): + pos = self.rng.uniform(-5, 5, size=3) + bodies.append(RigidBody(body_id=i, position=pos)) + + # Build spanning tree with fixed joints (overconstrained) + joints: list[Joint] = [] + for i in range(1, n_bodies): + parent = self.rng.integers(0, i) + mid = (bodies[i].position + bodies[parent].position) / 2 + axis = self.rng.standard_normal(3) + axis /= np.linalg.norm(axis) + + joints.append( + Joint( + joint_id=i - 1, + body_a=int(parent), + body_b=i, + joint_type=JointType.FIXED, + anchor_a=mid, + anchor_b=mid, + axis=axis, + ) + ) + + # Try relaxing joints to weaker types while maintaining rigidity + weaker_types = [ + JointType.REVOLUTE, + JointType.CYLINDRICAL, + JointType.BALL, + ] + + for idx in self.rng.permutation(len(joints)): + original_type = joints[idx].joint_type + for candidate in weaker_types: + joints[idx].joint_type = candidate + analysis = analyze_assembly(bodies, joints, ground_body=0) + if analysis.is_rigid: + break # Keep the weaker type + else: + joints[idx].joint_type = original_type + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_overconstrained_assembly( + self, + n_bodies: int, + extra_joints: int = 2, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate an assembly with known redundant constraints. + + Starts with a rigid assembly, then adds extra joints that + the pebble game will flag as redundant. + """ + bodies, joints, _ = self.generate_rigid_assembly(n_bodies) + + joint_id = len(joints) + for _ in range(extra_joints): + a, b = self.rng.choice(n_bodies, size=2, replace=False) + mid = (bodies[a].position + bodies[b].position) / 2 + axis = self.rng.standard_normal(3) + axis /= np.linalg.norm(axis) + + _overcon_types = [ + JointType.REVOLUTE, + JointType.FIXED, + JointType.BALL, + ] + jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))] + joints.append( + Joint( + joint_id=joint_id, + body_a=int(a), + body_b=int(b), + joint_type=jtype, + anchor_a=mid, + anchor_b=mid, + axis=axis, + ) + ) + joint_id += 1 + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_training_batch( + self, + batch_size: int = 100, + n_bodies_range: tuple[int, int] = (3, 8), + ) -> list[dict[str, Any]]: + """Generate a batch of labeled training examples. + + Each example contains: + - bodies: list of body positions/orientations + - joints: list of joints with types and parameters + - labels: per-joint independence/redundancy flags + - assembly_label: overall classification + """ + examples: list[dict[str, Any]] = [] + + for i in range(batch_size): + n = int(self.rng.integers(*n_bodies_range)) + gen_idx = int(self.rng.integers(3)) + + if gen_idx == 0: + _chain_types = [ + JointType.REVOLUTE, + JointType.BALL, + JointType.CYLINDRICAL, + ] + jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] + bodies, joints, analysis = self.generate_chain_assembly(n, jtype) + elif gen_idx == 1: + bodies, joints, analysis = self.generate_rigid_assembly(n) + else: + extra = int(self.rng.integers(1, 4)) + bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra) + + # Build per-joint labels from edge results + joint_labels: dict[int, dict[str, int]] = {} + for result in analysis.per_edge_results: + jid = result["joint_id"] + if jid not in joint_labels: + joint_labels[jid] = { + "independent_constraints": 0, + "redundant_constraints": 0, + "total_constraints": 0, + } + joint_labels[jid]["total_constraints"] += 1 + if result["independent"]: + joint_labels[jid]["independent_constraints"] += 1 + else: + joint_labels[jid]["redundant_constraints"] += 1 + + examples.append( + { + "example_id": i, + "n_bodies": len(bodies), + "n_joints": len(joints), + "body_positions": [b.position.tolist() for b in bodies], + "joints": [ + { + "joint_id": j.joint_id, + "body_a": j.body_a, + "body_b": j.body_b, + "type": j.joint_type.name, + "axis": j.axis.tolist(), + } + for j in joints + ], + "joint_labels": joint_labels, + "assembly_classification": analysis.combinatorial_classification, + "is_rigid": analysis.is_rigid, + "is_minimally_rigid": analysis.is_minimally_rigid, + "internal_dof": analysis.jacobian_internal_dof, + "geometric_degeneracies": analysis.geometric_degeneracies, + } + ) + + return examples From dc742bfc82727581cd8b4f7ba99aae5e5653d6b0 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 14:08:22 -0600 Subject: [PATCH 08/12] test: add unit tests for datagen modules - test_types.py: JointType enum values/count, dataclass defaults/isolation - test_pebble_game.py: DOF accounting, rigidity, classification, edge results - test_jacobian.py: Jacobian shape per joint type, rank, parallel axis degeneracy - test_analysis.py: demo scenarios (revolute, fixed, triangle, parallel axes) - test_generator.py: chain/rigid/overconstrained generation, training batch Bug fixes found during testing: - JointType enum: duplicate int values caused aliasing (SLIDER=REVOLUTE etc). Changed to (ordinal, dof) tuple values with a .dof property. - pebble_game.py: .value -> .dof for constraint count - analysis.py: classify from effective DOF (not raw pebble game with virtual ground body skew) 105 tests, all passing. Closes #6 --- solver/datagen/analysis.py | 12 +- solver/datagen/pebble_game.py | 2 +- solver/datagen/types.py | 31 ++-- tests/datagen/__init__.py | 0 tests/datagen/test_analysis.py | 240 +++++++++++++++++++++++++++ tests/datagen/test_generator.py | 166 +++++++++++++++++++ tests/datagen/test_jacobian.py | 267 ++++++++++++++++++++++++++++++ tests/datagen/test_pebble_game.py | 206 +++++++++++++++++++++++ tests/datagen/test_types.py | 163 ++++++++++++++++++ 9 files changed, 1073 insertions(+), 14 deletions(-) create mode 100644 tests/datagen/__init__.py create mode 100644 tests/datagen/test_analysis.py create mode 100644 tests/datagen/test_generator.py create mode 100644 tests/datagen/test_jacobian.py create mode 100644 tests/datagen/test_pebble_game.py create mode 100644 tests/datagen/test_types.py diff --git a/solver/datagen/analysis.py b/solver/datagen/analysis.py index f862802..6ad8491 100644 --- a/solver/datagen/analysis.py +++ b/solver/datagen/analysis.py @@ -78,7 +78,17 @@ def analyze_assembly( effective_dof = raw_dof - ground_offset effective_internal_dof = max(0, effective_dof - (0 if grounded else 6)) - combinatorial_classification = pg.classify_assembly(grounded=grounded) + # Classify based on effective (adjusted) DOF, not raw pebble game output, + # because the virtual ground body skews the raw numbers. + redundant = pg.get_redundant_count() + if redundant > 0 and effective_internal_dof > 0: + combinatorial_classification = "mixed" + elif redundant > 0: + combinatorial_classification = "overconstrained" + elif effective_internal_dof > 0: + combinatorial_classification = "underconstrained" + else: + combinatorial_classification = "well-constrained" # --- Jacobian Verification --- verifier = JacobianVerifier(bodies) diff --git a/solver/datagen/pebble_game.py b/solver/datagen/pebble_game.py index 655cd13..e3bdbc3 100644 --- a/solver/datagen/pebble_game.py +++ b/solver/datagen/pebble_game.py @@ -79,7 +79,7 @@ class PebbleGame3D: self.add_body(joint.body_a) self.add_body(joint.body_b) - num_constraints = joint.joint_type.value + num_constraints = joint.joint_type.dof results: list[dict[str, Any]] = [] for i in range(num_constraints): diff --git a/solver/datagen/types.py b/solver/datagen/types.py index 754e5dc..f7b3e84 100644 --- a/solver/datagen/types.py +++ b/solver/datagen/types.py @@ -37,20 +37,27 @@ class JointType(enum.Enum): representation, each joint maps to a number of edges equal to the DOF it removes. - DOF removed = number of scalar constraint equations the joint imposes. + Values are ``(ordinal, dof_removed)`` tuples so that joint types + sharing the same DOF count remain distinct enum members. Use the + :attr:`dof` property to get the scalar constraint count. """ - FIXED = 6 # Locks all relative motion - REVOLUTE = 5 # Allows rotation about one axis - CYLINDRICAL = 4 # Allows rotation + translation along one axis - SLIDER = 5 # Allows translation along one axis (prismatic) - BALL = 3 # Allows rotation about a point (spherical) - PLANAR = 3 # Allows 2D translation + rotation normal to plane - SCREW = 5 # Coupled rotation-translation (helical) - UNIVERSAL = 4 # Two rotational DOF (Cardan/U-joint) - PARALLEL = 3 # Forces parallel orientation (3 rotation constraints) - PERPENDICULAR = 1 # Single angular constraint - DISTANCE = 1 # Single scalar distance constraint + FIXED = (0, 6) # Locks all relative motion + REVOLUTE = (1, 5) # Allows rotation about one axis + CYLINDRICAL = (2, 4) # Allows rotation + translation along one axis + SLIDER = (3, 5) # Allows translation along one axis (prismatic) + BALL = (4, 3) # Allows rotation about a point (spherical) + PLANAR = (5, 3) # Allows 2D translation + rotation normal to plane + SCREW = (6, 5) # Coupled rotation-translation (helical) + UNIVERSAL = (7, 4) # Two rotational DOF (Cardan/U-joint) + PARALLEL = (8, 3) # Forces parallel orientation (3 rotation constraints) + PERPENDICULAR = (9, 1) # Single angular constraint + DISTANCE = (10, 1) # Single scalar distance constraint + + @property + def dof(self) -> int: + """Number of scalar constraints (DOF removed) by this joint type.""" + return self.value[1] # --------------------------------------------------------------------------- diff --git a/tests/datagen/__init__.py b/tests/datagen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/datagen/test_analysis.py b/tests/datagen/test_analysis.py new file mode 100644 index 0000000..8edf893 --- /dev/null +++ b/tests/datagen/test_analysis.py @@ -0,0 +1,240 @@ +"""Tests for solver.datagen.analysis -- combined analysis function.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from solver.datagen.analysis import analyze_assembly +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _two_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + ] + + +def _triangle_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([1.0, 1.7, 0.0])), + ] + + +# --------------------------------------------------------------------------- +# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF) +# --------------------------------------------------------------------------- + + +class TestTwoBodiesRevolute: + """Demo scenario 1: two bodies connected by a revolute joint.""" + + @pytest.fixture() + def result(self) -> ConstraintAnalysis: + bodies = _two_bodies() + joints = [ + Joint( + 0, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + ] + return analyze_assembly(bodies, joints, ground_body=0) + + def test_internal_dof(self, result: ConstraintAnalysis) -> None: + assert result.jacobian_internal_dof == 1 + + def test_not_rigid(self, result: ConstraintAnalysis) -> None: + assert not result.is_rigid + + def test_classification(self, result: ConstraintAnalysis) -> None: + assert result.combinatorial_classification == "underconstrained" + + def test_no_redundant(self, result: ConstraintAnalysis) -> None: + assert result.combinatorial_redundant == 0 + + +# --------------------------------------------------------------------------- +# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF) +# --------------------------------------------------------------------------- + + +class TestTwoBodiesFixed: + """Demo scenario 2: two bodies connected by a fixed joint.""" + + @pytest.fixture() + def result(self) -> ConstraintAnalysis: + bodies = _two_bodies() + joints = [ + Joint( + 0, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ), + ] + return analyze_assembly(bodies, joints, ground_body=0) + + def test_internal_dof(self, result: ConstraintAnalysis) -> None: + assert result.jacobian_internal_dof == 0 + + def test_rigid(self, result: ConstraintAnalysis) -> None: + assert result.is_rigid + + def test_minimally_rigid(self, result: ConstraintAnalysis) -> None: + assert result.is_minimally_rigid + + def test_classification(self, result: ConstraintAnalysis) -> None: + assert result.combinatorial_classification == "well-constrained" + + +# --------------------------------------------------------------------------- +# Scenario 3: Triangle with revolute joints (overconstrained) +# --------------------------------------------------------------------------- + + +class TestTriangleRevolute: + """Demo scenario 3: triangle of 3 bodies + 3 revolute joints.""" + + @pytest.fixture() + def result(self) -> ConstraintAnalysis: + bodies = _triangle_bodies() + joints = [ + Joint( + 0, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + Joint( + 1, + body_a=1, + body_b=2, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.5, 0.85, 0.0]), + anchor_b=np.array([1.5, 0.85, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + Joint( + 2, + body_a=2, + body_b=0, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([0.5, 0.85, 0.0]), + anchor_b=np.array([0.5, 0.85, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + ] + return analyze_assembly(bodies, joints, ground_body=0) + + def test_has_redundant(self, result: ConstraintAnalysis) -> None: + assert result.combinatorial_redundant > 0 + + def test_classification(self, result: ConstraintAnalysis) -> None: + assert result.combinatorial_classification in ("overconstrained", "mixed") + + def test_rigid(self, result: ConstraintAnalysis) -> None: + assert result.is_rigid + + def test_numerically_dependent(self, result: ConstraintAnalysis) -> None: + assert len(result.numerically_dependent) > 0 + + +# --------------------------------------------------------------------------- +# Scenario 4: Parallel revolute axes (geometric degeneracy) +# --------------------------------------------------------------------------- + + +class TestParallelRevoluteAxes: + """Demo scenario 4: parallel revolute axes create geometric degeneracies.""" + + @pytest.fixture() + def result(self) -> ConstraintAnalysis: + bodies = [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([4.0, 0.0, 0.0])), + ] + joints = [ + Joint( + 0, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + Joint( + 1, + body_a=1, + body_b=2, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([3.0, 0.0, 0.0]), + anchor_b=np.array([3.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ), + ] + return analyze_assembly(bodies, joints, ground_body=0) + + def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None: + """Parallel axes produce at least one geometric degeneracy.""" + assert result.geometric_degeneracies > 0 + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestNoJoints: + """Assembly with bodies but no joints.""" + + def test_all_dof_free(self) -> None: + bodies = _two_bodies() + result = analyze_assembly(bodies, [], ground_body=0) + # Body 1 is completely free (6 DOF), body 0 is grounded + assert result.jacobian_internal_dof > 0 + assert not result.is_rigid + + def test_ungrounded(self) -> None: + bodies = _two_bodies() + result = analyze_assembly(bodies, []) + assert result.combinatorial_classification == "underconstrained" + + +class TestReturnType: + """Verify the return object is a proper ConstraintAnalysis.""" + + def test_instance(self) -> None: + bodies = _two_bodies() + joints = [Joint(0, 0, 1, JointType.FIXED)] + result = analyze_assembly(bodies, joints) + assert isinstance(result, ConstraintAnalysis) + + def test_per_edge_results_populated(self) -> None: + bodies = _two_bodies() + joints = [Joint(0, 0, 1, JointType.REVOLUTE)] + result = analyze_assembly(bodies, joints) + assert len(result.per_edge_results) == 5 diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py new file mode 100644 index 0000000..06c10a5 --- /dev/null +++ b/tests/datagen/test_generator.py @@ -0,0 +1,166 @@ +"""Tests for solver.datagen.generator -- synthetic assembly generation.""" + +from __future__ import annotations + +import pytest + +from solver.datagen.generator import SyntheticAssemblyGenerator +from solver.datagen.types import JointType + + +class TestChainAssembly: + """generate_chain_assembly produces valid underconstrained chains.""" + + def test_returns_three_tuple(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + bodies, joints, _analysis = gen.generate_chain_assembly(4) + assert len(bodies) == 4 + assert len(joints) == 3 + + def test_chain_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, _, analysis = gen.generate_chain_assembly(4) + assert analysis.combinatorial_classification == "underconstrained" + + def test_chain_body_ids(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + bodies, _, _ = gen.generate_chain_assembly(5) + ids = [b.body_id for b in bodies] + assert ids == [0, 1, 2, 3, 4] + + def test_chain_joint_connectivity(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, joints, _ = gen.generate_chain_assembly(4) + for i, j in enumerate(joints): + assert j.body_a == i + assert j.body_b == i + 1 + + def test_chain_custom_joint_type(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + _, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL) + assert all(j.joint_type is JointType.BALL for j in joints) + + +class TestRigidAssembly: + """generate_rigid_assembly produces rigid assemblies.""" + + def test_rigid(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_rigid_assembly(4) + assert analysis.is_rigid + + def test_spanning_tree_structure(self) -> None: + """n bodies should have at least n-1 joints (spanning tree).""" + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_rigid_assembly(5) + assert len(joints) >= len(bodies) - 1 + + def test_deterministic(self) -> None: + """Same seed produces same results.""" + g1 = SyntheticAssemblyGenerator(seed=99) + g2 = SyntheticAssemblyGenerator(seed=99) + _, j1, a1 = g1.generate_rigid_assembly(4) + _, j2, a2 = g2.generate_rigid_assembly(4) + assert a1.jacobian_rank == a2.jacobian_rank + assert len(j1) == len(j2) + + +class TestOverconstrainedAssembly: + """generate_overconstrained_assembly adds redundant constraints.""" + + def test_has_redundant(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2) + assert analysis.combinatorial_redundant > 0 + + def test_extra_joints_added(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints_base, _ = gen.generate_rigid_assembly(4) + + gen2 = SyntheticAssemblyGenerator(seed=42) + _, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3) + # Overconstrained has base joints + extra + assert len(joints_over) > len(joints_base) + + +class TestTrainingBatch: + """generate_training_batch produces well-structured examples.""" + + @pytest.fixture() + def batch(self) -> list[dict]: + gen = SyntheticAssemblyGenerator(seed=42) + return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6)) + + def test_batch_size(self, batch: list[dict]) -> None: + assert len(batch) == 20 + + def test_example_keys(self, batch: list[dict]) -> None: + expected = { + "example_id", + "n_bodies", + "n_joints", + "body_positions", + "joints", + "joint_labels", + "assembly_classification", + "is_rigid", + "is_minimally_rigid", + "internal_dof", + "geometric_degeneracies", + } + for ex in batch: + assert set(ex.keys()) == expected + + def test_example_ids_sequential(self, batch: list[dict]) -> None: + ids = [ex["example_id"] for ex in batch] + assert ids == list(range(20)) + + def test_classification_distribution(self, batch: list[dict]) -> None: + """Batch should contain multiple classification types.""" + classes = {ex["assembly_classification"] for ex in batch} + # With the 3-way generator split we expect at least 2 types + assert len(classes) >= 2 + + def test_body_count_in_range(self, batch: list[dict]) -> None: + for ex in batch: + assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6) + + def test_joint_labels_match_joints(self, batch: list[dict]) -> None: + for ex in batch: + label_jids = set(ex["joint_labels"].keys()) + joint_jids = {j["joint_id"] for j in ex["joints"]} + assert label_jids == joint_jids + + def test_joint_label_fields(self, batch: list[dict]) -> None: + expected_fields = { + "independent_constraints", + "redundant_constraints", + "total_constraints", + } + for ex in batch: + for label in ex["joint_labels"].values(): + assert set(label.keys()) == expected_fields + + def test_joint_label_consistency(self, batch: list[dict]) -> None: + """independent + redundant == total for every joint.""" + for ex in batch: + for label in ex["joint_labels"].values(): + total = label["independent_constraints"] + label["redundant_constraints"] + assert total == label["total_constraints"] + + +class TestSeedReproducibility: + """Different seeds produce different results.""" + + def test_different_seeds_differ(self) -> None: + g1 = SyntheticAssemblyGenerator(seed=1) + g2 = SyntheticAssemblyGenerator(seed=2) + b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) + b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) + # Very unlikely to be identical with different seeds + c1 = [ex["assembly_classification"] for ex in b1] + c2 = [ex["assembly_classification"] for ex in b2] + r1 = [ex["is_rigid"] for ex in b1] + r2 = [ex["is_rigid"] for ex in b2] + # At least one of these should differ (probabilistically certain) + assert c1 != c2 or r1 != r2 diff --git a/tests/datagen/test_jacobian.py b/tests/datagen/test_jacobian.py new file mode 100644 index 0000000..864331d --- /dev/null +++ b/tests/datagen/test_jacobian.py @@ -0,0 +1,267 @@ +"""Tests for solver.datagen.jacobian -- Jacobian rank verification.""" + +from __future__ import annotations + +import numpy as np +import pytest + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.types import Joint, JointType, RigidBody + +# --------------------------------------------------------------------------- +# Fixtures +# --------------------------------------------------------------------------- + + +def _two_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + ] + + +def _three_bodies() -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([4.0, 0.0, 0.0])), + ] + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestJacobianShape: + """Verify Jacobian matrix dimensions for each joint type.""" + + @pytest.mark.parametrize( + "joint_type,expected_rows", + [ + (JointType.FIXED, 6), + (JointType.REVOLUTE, 5), + (JointType.CYLINDRICAL, 4), + (JointType.SLIDER, 5), + (JointType.BALL, 3), + (JointType.PLANAR, 3), + (JointType.SCREW, 5), + (JointType.UNIVERSAL, 4), + (JointType.PARALLEL, 3), + (JointType.PERPENDICULAR, 1), + (JointType.DISTANCE, 1), + ], + ) + def test_row_count(self, joint_type: JointType, expected_rows: int) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + joint = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=joint_type, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + n_added = v.add_joint_constraints(joint) + assert n_added == expected_rows + + j = v.get_jacobian() + assert j.shape == (expected_rows, 12) # 2 bodies * 6 cols + + +class TestNumericalRank: + """Numerical rank checks for known configurations.""" + + def test_fixed_joint_rank_six(self) -> None: + """Fixed joint between 2 bodies: rank = 6.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 6 + + def test_revolute_joint_rank_five(self) -> None: + """Revolute joint between 2 bodies: rank = 5.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 5 + + def test_ball_joint_rank_three(self) -> None: + """Ball joint between 2 bodies: rank = 3.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.BALL, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.numerical_rank() == 3 + + def test_empty_jacobian_rank_zero(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + assert v.numerical_rank() == 0 + + +class TestParallelAxesDegeneracy: + """Parallel revolute axes create geometric dependencies.""" + + def _four_body_loop(self) -> list[RigidBody]: + return [ + RigidBody(0, position=np.array([0.0, 0.0, 0.0])), + RigidBody(1, position=np.array([2.0, 0.0, 0.0])), + RigidBody(2, position=np.array([2.0, 2.0, 0.0])), + RigidBody(3, position=np.array([0.0, 2.0, 0.0])), + ] + + def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]: + pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])] + return [ + Joint( + joint_id=i, + body_a=a, + body_b=b, + joint_type=JointType.REVOLUTE, + anchor_a=np.array(anc, dtype=float), + anchor_b=np.array(anc, dtype=float), + axis=axes[i], + ) + for i, (a, b, anc) in enumerate(pairs) + ] + + def test_parallel_has_lower_rank(self) -> None: + """4-body closed loop: all-parallel revolute axes produce lower + Jacobian rank than mixed axes due to geometric dependency.""" + bodies = self._four_body_loop() + z_axis = np.array([0.0, 0.0, 1.0]) + + # All axes parallel to Z + v_par = JacobianVerifier(bodies) + for j in self._loop_joints([z_axis] * 4): + v_par.add_joint_constraints(j) + rank_par = v_par.numerical_rank() + + # Mixed axes + mixed = [ + np.array([0.0, 0.0, 1.0]), + np.array([0.0, 1.0, 0.0]), + np.array([0.0, 0.0, 1.0]), + np.array([1.0, 0.0, 0.0]), + ] + v_mix = JacobianVerifier(bodies) + for j in self._loop_joints(mixed): + v_mix.add_joint_constraints(j) + rank_mix = v_mix.numerical_rank() + + assert rank_par < rank_mix + + +class TestFindDependencies: + """Dependency detection.""" + + def test_fixed_joint_no_dependencies(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=0, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + v.add_joint_constraints(j) + assert v.find_dependencies() == [] + + def test_duplicate_fixed_has_dependencies(self) -> None: + """Two fixed joints on same pair: second is fully dependent.""" + bodies = _two_bodies() + v = JacobianVerifier(bodies) + for jid in range(2): + v.add_joint_constraints( + Joint( + joint_id=jid, + body_a=0, + body_b=1, + joint_type=JointType.FIXED, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + ) + ) + deps = v.find_dependencies() + assert len(deps) == 6 # Second fixed joint entirely redundant + + def test_empty_no_dependencies(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + assert v.find_dependencies() == [] + + +class TestRowLabels: + """Row label metadata.""" + + def test_labels_match_rows(self) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + j = Joint( + joint_id=7, + body_a=0, + body_b=1, + joint_type=JointType.REVOLUTE, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([1.0, 0.0, 0.0]), + axis=np.array([0.0, 0.0, 1.0]), + ) + v.add_joint_constraints(j) + assert len(v.row_labels) == 5 + assert all(lab["joint_id"] == 7 for lab in v.row_labels) + + +class TestPerpendicularPair: + """Internal _perpendicular_pair utility.""" + + @pytest.mark.parametrize( + "axis", + [ + np.array([1.0, 0.0, 0.0]), + np.array([0.0, 1.0, 0.0]), + np.array([0.0, 0.0, 1.0]), + np.array([1.0, 1.0, 1.0]) / np.sqrt(3), + ], + ) + def test_orthonormal(self, axis: np.ndarray) -> None: + bodies = _two_bodies() + v = JacobianVerifier(bodies) + t1, t2 = v._perpendicular_pair(axis) + + # All unit length + np.testing.assert_allclose(np.linalg.norm(t1), 1.0, atol=1e-12) + np.testing.assert_allclose(np.linalg.norm(t2), 1.0, atol=1e-12) + + # Mutually perpendicular + np.testing.assert_allclose(np.dot(axis, t1), 0.0, atol=1e-12) + np.testing.assert_allclose(np.dot(axis, t2), 0.0, atol=1e-12) + np.testing.assert_allclose(np.dot(t1, t2), 0.0, atol=1e-12) diff --git a/tests/datagen/test_pebble_game.py b/tests/datagen/test_pebble_game.py new file mode 100644 index 0000000..da71eab --- /dev/null +++ b/tests/datagen/test_pebble_game.py @@ -0,0 +1,206 @@ +"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game.""" + +from __future__ import annotations + +import itertools + +import numpy as np +import pytest + +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import Joint, JointType + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint: + """Shorthand for a revolute joint between bodies *a* and *b*.""" + if axis is None: + axis = np.array([0.0, 0.0, 1.0]) + return Joint( + joint_id=jid, + body_a=a, + body_b=b, + joint_type=JointType.REVOLUTE, + axis=axis, + ) + + +def _fixed(jid: int, a: int, b: int) -> Joint: + return Joint(joint_id=jid, body_a=a, body_b=b, joint_type=JointType.FIXED) + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + + +class TestAddBody: + """Body registration basics.""" + + def test_single_body_six_pebbles(self) -> None: + pg = PebbleGame3D() + pg.add_body(0) + assert pg.state.free_pebbles[0] == 6 + + def test_duplicate_body_no_op(self) -> None: + pg = PebbleGame3D() + pg.add_body(0) + pg.add_body(0) + assert pg.state.free_pebbles[0] == 6 + + def test_multiple_bodies(self) -> None: + pg = PebbleGame3D() + for i in range(5): + pg.add_body(i) + assert pg.get_dof() == 30 # 5 * 6 + + +class TestAddJoint: + """Joint insertion and DOF accounting.""" + + def test_revolute_removes_five_dof(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + assert len(results) == 5 # 5 scalar constraints + assert all(r["independent"] for r in results) + # 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles + assert pg.get_dof() == 7 + + def test_fixed_removes_six_dof(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_fixed(0, 0, 1)) + assert len(results) == 6 + assert all(r["independent"] for r in results) + assert pg.get_dof() == 6 + + def test_ball_removes_three_dof(self) -> None: + pg = PebbleGame3D() + j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL) + results = pg.add_joint(j) + assert len(results) == 3 + assert all(r["independent"] for r in results) + assert pg.get_dof() == 9 + + +class TestTwoBodiesRevolute: + """Two bodies connected by a revolute -- demo scenario 1.""" + + def test_internal_dof(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + # Total DOF = 7, internal = 7 - 6 = 1 + assert pg.get_internal_dof() == 1 + + def test_not_rigid(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + assert not pg.is_rigid() + + def test_classification(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + assert pg.classify_assembly() == "underconstrained" + + +class TestTwoBodiesFixed: + """Two bodies + fixed joint -- demo scenario 2.""" + + def test_zero_internal_dof(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.get_internal_dof() == 0 + + def test_rigid(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.is_rigid() + + def test_well_constrained(self) -> None: + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + assert pg.classify_assembly() == "well-constrained" + + +class TestTriangleRevolute: + """Triangle of 3 bodies with revolute joints -- demo scenario 3.""" + + @pytest.fixture() + def pg(self) -> PebbleGame3D: + pg = PebbleGame3D() + pg.add_joint(_revolute(0, 0, 1)) + pg.add_joint(_revolute(1, 1, 2)) + pg.add_joint(_revolute(2, 2, 0)) + return pg + + def test_has_redundant_edges(self, pg: PebbleGame3D) -> None: + assert pg.get_redundant_count() > 0 + + def test_classification_overconstrained(self, pg: PebbleGame3D) -> None: + # 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed) + assert pg.classify_assembly() in ("overconstrained", "mixed") + + def test_rigid(self, pg: PebbleGame3D) -> None: + assert pg.is_rigid() + + +class TestChainNotRigid: + """A serial chain of 4 bodies with revolute joints is never rigid.""" + + def test_chain_underconstrained(self) -> None: + pg = PebbleGame3D() + for i in range(3): + pg.add_joint(_revolute(i, i, i + 1)) + assert not pg.is_rigid() + assert pg.classify_assembly() == "underconstrained" + + def test_chain_internal_dof(self) -> None: + pg = PebbleGame3D() + for i in range(3): + pg.add_joint(_revolute(i, i, i + 1)) + # 4 bodies * 6 = 24, minus 15 independent = 9 free, internal = 3 + assert pg.get_internal_dof() == 3 + + +class TestEdgeResults: + """Result dicts returned by add_joint.""" + + def test_result_keys(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"} + for r in results: + assert set(r.keys()) == expected_keys + + def test_edge_ids_sequential(self) -> None: + pg = PebbleGame3D() + r1 = pg.add_joint(_revolute(0, 0, 1)) + r2 = pg.add_joint(_revolute(1, 1, 2)) + all_ids = [r["edge_id"] for r in r1 + r2] + assert all_ids == list(range(10)) + + def test_dof_remaining_monotonic(self) -> None: + pg = PebbleGame3D() + results = pg.add_joint(_revolute(0, 0, 1)) + dofs = [r["dof_remaining"] for r in results] + # Should be non-increasing (each independent edge removes a pebble) + for a, b in itertools.pairwise(dofs): + assert a >= b + + +class TestGroundedClassification: + """classify_assembly with grounded=True.""" + + def test_grounded_baseline_zero(self) -> None: + """With grounded=True the baseline is 0 (not 6).""" + pg = PebbleGame3D() + pg.add_joint(_fixed(0, 0, 1)) + # Ungrounded: well-constrained (6 pebbles = baseline 6) + assert pg.classify_assembly(grounded=False) == "well-constrained" + # Grounded: the 6 remaining pebbles on body 1 exceed baseline 0, + # so the raw pebble game (without a virtual ground body) sees this + # as underconstrained. The analysis function handles this properly + # by adding a virtual ground body. + assert pg.classify_assembly(grounded=True) == "underconstrained" diff --git a/tests/datagen/test_types.py b/tests/datagen/test_types.py new file mode 100644 index 0000000..7667243 --- /dev/null +++ b/tests/datagen/test_types.py @@ -0,0 +1,163 @@ +"""Tests for solver.datagen.types -- shared data types.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np +import pytest + +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + PebbleState, + RigidBody, +) + + +class TestJointType: + """JointType enum construction and DOF values.""" + + EXPECTED_DOF: ClassVar[dict[str, int]] = { + "FIXED": 6, + "REVOLUTE": 5, + "CYLINDRICAL": 4, + "SLIDER": 5, + "BALL": 3, + "PLANAR": 3, + "SCREW": 5, + "UNIVERSAL": 4, + "PARALLEL": 3, + "PERPENDICULAR": 1, + "DISTANCE": 1, + } + + def test_member_count(self) -> None: + assert len(JointType) == 11 + + @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items()) + def test_dof_values(self, name: str, dof: int) -> None: + assert JointType[name].dof == dof + + def test_access_by_name(self) -> None: + assert JointType["REVOLUTE"] is JointType.REVOLUTE + + def test_value_is_tuple(self) -> None: + assert JointType.REVOLUTE.value == (1, 5) + assert JointType.REVOLUTE.dof == 5 + + +class TestRigidBody: + """RigidBody dataclass defaults and construction.""" + + def test_defaults(self) -> None: + body = RigidBody(body_id=0) + np.testing.assert_array_equal(body.position, np.zeros(3)) + np.testing.assert_array_equal(body.orientation, np.eye(3)) + assert body.local_anchors == {} + + def test_custom_position(self) -> None: + pos = np.array([1.0, 2.0, 3.0]) + body = RigidBody(body_id=7, position=pos) + np.testing.assert_array_equal(body.position, pos) + assert body.body_id == 7 + + def test_local_anchors_mutable(self) -> None: + body = RigidBody(body_id=0) + body.local_anchors["top"] = np.array([0.0, 0.0, 1.0]) + assert "top" in body.local_anchors + + def test_default_factory_isolation(self) -> None: + """Each instance gets its own default containers.""" + b1 = RigidBody(body_id=0) + b2 = RigidBody(body_id=1) + b1.local_anchors["x"] = np.zeros(3) + assert "x" not in b2.local_anchors + + +class TestJoint: + """Joint dataclass defaults and construction.""" + + def test_defaults(self) -> None: + j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE) + np.testing.assert_array_equal(j.anchor_a, np.zeros(3)) + np.testing.assert_array_equal(j.anchor_b, np.zeros(3)) + np.testing.assert_array_equal(j.axis, np.array([0.0, 0.0, 1.0])) + assert j.pitch == 0.0 + + def test_full_construction(self) -> None: + j = Joint( + joint_id=5, + body_a=2, + body_b=3, + joint_type=JointType.SCREW, + anchor_a=np.array([1.0, 0.0, 0.0]), + anchor_b=np.array([2.0, 0.0, 0.0]), + axis=np.array([1.0, 0.0, 0.0]), + pitch=0.5, + ) + assert j.joint_id == 5 + assert j.joint_type is JointType.SCREW + assert j.pitch == 0.5 + + +class TestPebbleState: + """PebbleState dataclass defaults.""" + + def test_defaults(self) -> None: + s = PebbleState() + assert s.free_pebbles == {} + assert s.directed_edges == {} + assert s.independent_edges == set() + assert s.redundant_edges == set() + assert s.incoming == {} + assert s.outgoing == {} + + def test_default_factory_isolation(self) -> None: + s1 = PebbleState() + s2 = PebbleState() + s1.free_pebbles[0] = 6 + assert 0 not in s2.free_pebbles + + +class TestConstraintAnalysis: + """ConstraintAnalysis dataclass construction.""" + + def test_construction(self) -> None: + ca = ConstraintAnalysis( + combinatorial_dof=6, + combinatorial_internal_dof=0, + combinatorial_redundant=0, + combinatorial_classification="well-constrained", + per_edge_results=[], + jacobian_rank=6, + jacobian_nullity=0, + jacobian_internal_dof=0, + numerically_dependent=[], + geometric_degeneracies=0, + is_rigid=True, + is_minimally_rigid=True, + ) + assert ca.is_rigid is True + assert ca.is_minimally_rigid is True + assert ca.combinatorial_classification == "well-constrained" + + def test_per_edge_results_typing(self) -> None: + """per_edge_results accepts list[dict[str, Any]].""" + ca = ConstraintAnalysis( + combinatorial_dof=7, + combinatorial_internal_dof=1, + combinatorial_redundant=0, + combinatorial_classification="underconstrained", + per_edge_results=[{"edge_id": 0, "independent": True}], + jacobian_rank=5, + jacobian_nullity=1, + jacobian_internal_dof=1, + numerically_dependent=[], + geometric_degeneracies=0, + is_rigid=False, + is_minimally_rigid=False, + ) + assert len(ca.per_edge_results) == 1 + assert ca.per_edge_results[0]["edge_id"] == 0 From 0b5813b5a9d2c67b0580cd596ee84f0b9c301a2d Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 14:38:05 -0600 Subject: [PATCH 09/12] feat: parameterized assembly templates and complexity tiers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add 4 new topology generators to SyntheticAssemblyGenerator: - generate_tree_assembly: random spanning tree with configurable branching - generate_loop_assembly: closed ring producing overconstrained data - generate_star_assembly: hub-and-spoke topology - generate_mixed_assembly: tree + loops with configurable edge density Each accepts joint_types as JointType | list[JointType] for per-joint type sampling. Add complexity tiers (simple/medium/complex) with predefined body count ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias. Update generate_training_batch with 7-way generator selection, complexity_tier parameter, and generator_type field in output dicts. Extract private helpers (_random_position, _random_axis, _select_joint_type, _create_joint) to reduce duplication. 44 generator tests, 130 total — all passing. Closes #7 --- solver/datagen/__init__.py | 3 +- solver/datagen/generator.py | 358 +++++++++++++++++++++++++++++--- tests/datagen/test_generator.py | 305 ++++++++++++++++++++++----- 3 files changed, 590 insertions(+), 76 deletions(-) diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index bd6fc3c..203adcc 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,7 +1,7 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly -from solver.datagen.generator import SyntheticAssemblyGenerator +from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( @@ -13,6 +13,7 @@ from solver.datagen.types import ( ) __all__ = [ + "COMPLEXITY_RANGES", "ConstraintAnalysis", "JacobianVerifier", "Joint", diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py index 6c6bb0f..cd90ade 100644 --- a/solver/datagen/generator.py +++ b/solver/datagen/generator.py @@ -7,7 +7,7 @@ with per-constraint independence flags and assembly-level classification. from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Literal import numpy as np @@ -22,7 +22,19 @@ from solver.datagen.types import ( if TYPE_CHECKING: from typing import Any -__all__ = ["SyntheticAssemblyGenerator"] +__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"] + +# --------------------------------------------------------------------------- +# Complexity tiers — ranges use exclusive upper bound for rng.integers() +# --------------------------------------------------------------------------- + +ComplexityTier = Literal["simple", "medium", "complex"] + +COMPLEXITY_RANGES: dict[str, tuple[int, int]] = { + "simple": (2, 6), + "medium": (6, 16), + "complex": (16, 51), +} class SyntheticAssemblyGenerator: @@ -41,6 +53,55 @@ class SyntheticAssemblyGenerator: def __init__(self, seed: int = 42) -> None: self.rng = np.random.default_rng(seed) + # ------------------------------------------------------------------ + # Private helpers + # ------------------------------------------------------------------ + + def _random_position(self, scale: float = 5.0) -> np.ndarray: + """Generate random 3D position within [-scale, scale] cube.""" + return self.rng.uniform(-scale, scale, size=3) + + def _random_axis(self) -> np.ndarray: + """Generate random normalized 3D axis.""" + axis = self.rng.standard_normal(3) + axis /= np.linalg.norm(axis) + return axis + + def _select_joint_type( + self, + joint_types: JointType | list[JointType], + ) -> JointType: + """Select a joint type from a single type or list.""" + if isinstance(joint_types, list): + idx = int(self.rng.integers(0, len(joint_types))) + return joint_types[idx] + return joint_types + + def _create_joint( + self, + joint_id: int, + body_a_id: int, + body_b_id: int, + pos_a: np.ndarray, + pos_b: np.ndarray, + joint_type: JointType, + ) -> Joint: + """Create a joint between two bodies with random axis at midpoint.""" + anchor = (pos_a + pos_b) / 2.0 + return Joint( + joint_id=joint_id, + body_a=body_a_id, + body_b=body_b_id, + joint_type=joint_type, + anchor_a=anchor, + anchor_b=anchor, + axis=self._random_axis(), + ) + + # ------------------------------------------------------------------ + # Original generators (chain / rigid / overconstrained) + # ------------------------------------------------------------------ + def generate_chain_assembly( self, n_bodies: int, @@ -60,11 +121,7 @@ class SyntheticAssemblyGenerator: bodies.append(RigidBody(body_id=i, position=pos)) for i in range(n_bodies - 1): - axis = self.rng.standard_normal(3) - axis /= np.linalg.norm(axis) - anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0]) - joints.append( Joint( joint_id=i, @@ -73,7 +130,7 @@ class SyntheticAssemblyGenerator: joint_type=joint_type, anchor_a=anchor, anchor_b=anchor, - axis=axis, + axis=self._random_axis(), ) ) @@ -91,26 +148,22 @@ class SyntheticAssemblyGenerator: """ bodies = [] for i in range(n_bodies): - pos = self.rng.uniform(-5, 5, size=3) - bodies.append(RigidBody(body_id=i, position=pos)) + bodies.append(RigidBody(body_id=i, position=self._random_position())) # Build spanning tree with fixed joints (overconstrained) joints: list[Joint] = [] for i in range(1, n_bodies): - parent = self.rng.integers(0, i) + parent = int(self.rng.integers(0, i)) mid = (bodies[i].position + bodies[parent].position) / 2 - axis = self.rng.standard_normal(3) - axis /= np.linalg.norm(axis) - joints.append( Joint( joint_id=i - 1, - body_a=int(parent), + body_a=parent, body_b=i, joint_type=JointType.FIXED, anchor_a=mid, anchor_b=mid, - axis=axis, + axis=self._random_axis(), ) ) @@ -150,8 +203,6 @@ class SyntheticAssemblyGenerator: for _ in range(extra_joints): a, b = self.rng.choice(n_bodies, size=2, replace=False) mid = (bodies[a].position + bodies[b].position) / 2 - axis = self.rng.standard_normal(3) - axis /= np.linalg.norm(axis) _overcon_types = [ JointType.REVOLUTE, @@ -167,7 +218,7 @@ class SyntheticAssemblyGenerator: joint_type=jtype, anchor_a=mid, anchor_b=mid, - axis=axis, + axis=self._random_axis(), ) ) joint_id += 1 @@ -175,24 +226,259 @@ class SyntheticAssemblyGenerator: analysis = analyze_assembly(bodies, joints, ground_body=0) return bodies, joints, analysis + # ------------------------------------------------------------------ + # New topology generators + # ------------------------------------------------------------------ + + def generate_tree_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + branching_factor: int = 3, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a random tree topology with configurable branching. + + Creates a tree where each body can have up to *branching_factor* + children. Different branches can use different joint types if a + list is provided. Always underconstrained (no closed loops). + + Args: + n_bodies: Total bodies (root + children). + joint_types: Single type or list to sample from per joint. + branching_factor: Max children per parent (1-5 recommended). + """ + bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + joints: list[Joint] = [] + + available_parents = [0] + next_id = 1 + joint_id = 0 + + while next_id < n_bodies and available_parents: + pidx = int(self.rng.integers(0, len(available_parents))) + parent_id = available_parents[pidx] + parent_pos = bodies[parent_id].position + + max_children = min(branching_factor, n_bodies - next_id) + n_children = int(self.rng.integers(1, max_children + 1)) + + for _ in range(n_children): + direction = self._random_axis() + distance = self.rng.uniform(1.5, 3.0) + child_pos = parent_pos + direction * distance + + bodies.append(RigidBody(body_id=next_id, position=child_pos)) + jtype = self._select_joint_type(joint_types) + joints.append( + self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype) + ) + + available_parents.append(next_id) + next_id += 1 + joint_id += 1 + if next_id >= n_bodies: + break + + # Retire parent if it reached branching limit or randomly + if n_children >= branching_factor or self.rng.random() < 0.3: + available_parents.pop(pidx) + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_loop_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a single closed loop (ring) of bodies. + + The closing constraint introduces redundancy, making this + useful for generating overconstrained training data. + + Args: + n_bodies: Bodies in the ring (>= 3). + joint_types: Single type or list to sample from per joint. + + Raises: + ValueError: If *n_bodies* < 3. + """ + if n_bodies < 3: + msg = "Loop assembly requires at least 3 bodies" + raise ValueError(msg) + + bodies: list[RigidBody] = [] + joints: list[Joint] = [] + + base_radius = max(2.0, n_bodies * 0.4) + for i in range(n_bodies): + angle = 2 * np.pi * i / n_bodies + radius = base_radius + self.rng.uniform(-0.5, 0.5) + x = radius * np.cos(angle) + y = radius * np.sin(angle) + z = float(self.rng.uniform(-0.2, 0.2)) + bodies.append(RigidBody(body_id=i, position=np.array([x, y, z]))) + + for i in range(n_bodies): + next_i = (i + 1) % n_bodies + jtype = self._select_joint_type(joint_types) + joints.append( + self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype) + ) + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_star_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a star topology with central hub and satellites. + + Body 0 is the hub; all other bodies connect directly to it. + Underconstrained because there are no inter-satellite connections. + + Args: + n_bodies: Total bodies including hub (>= 2). + joint_types: Single type or list to sample from per joint. + + Raises: + ValueError: If *n_bodies* < 2. + """ + if n_bodies < 2: + msg = "Star assembly requires at least 2 bodies" + raise ValueError(msg) + + bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + joints: list[Joint] = [] + + for i in range(1, n_bodies): + direction = self._random_axis() + distance = self.rng.uniform(2.0, 5.0) + pos = direction * distance + bodies.append(RigidBody(body_id=i, position=pos)) + + jtype = self._select_joint_type(joint_types) + joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype)) + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + def generate_mixed_assembly( + self, + n_bodies: int, + joint_types: JointType | list[JointType] = JointType.REVOLUTE, + edge_density: float = 0.3, + ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: + """Generate a mixed topology combining tree and loop elements. + + Builds a spanning tree for connectivity, then adds extra edges + based on *edge_density* to create loops and redundancy. + + Args: + n_bodies: Number of bodies. + joint_types: Single type or list to sample from per joint. + edge_density: Fraction of non-tree edges to add (0.0-1.0). + + Raises: + ValueError: If *edge_density* not in [0.0, 1.0]. + """ + if not 0.0 <= edge_density <= 1.0: + msg = "edge_density must be in [0.0, 1.0]" + raise ValueError(msg) + + bodies: list[RigidBody] = [] + joints: list[Joint] = [] + + for i in range(n_bodies): + bodies.append(RigidBody(body_id=i, position=self._random_position())) + + # Phase 1: spanning tree + joint_id = 0 + existing_edges: set[frozenset[int]] = set() + for i in range(1, n_bodies): + parent = int(self.rng.integers(0, i)) + jtype = self._select_joint_type(joint_types) + joints.append( + self._create_joint( + joint_id, + parent, + i, + bodies[parent].position, + bodies[i].position, + jtype, + ) + ) + existing_edges.add(frozenset([parent, i])) + joint_id += 1 + + # Phase 2: extra edges based on density + candidates: list[tuple[int, int]] = [] + for i in range(n_bodies): + for j in range(i + 1, n_bodies): + if frozenset([i, j]) not in existing_edges: + candidates.append((i, j)) + + n_extra = int(edge_density * len(candidates)) + self.rng.shuffle(candidates) + + for a, b in candidates[:n_extra]: + jtype = self._select_joint_type(joint_types) + joints.append( + self._create_joint( + joint_id, + a, + b, + bodies[a].position, + bodies[b].position, + jtype, + ) + ) + joint_id += 1 + + analysis = analyze_assembly(bodies, joints, ground_body=0) + return bodies, joints, analysis + + # ------------------------------------------------------------------ + # Batch generation + # ------------------------------------------------------------------ + def generate_training_batch( self, batch_size: int = 100, - n_bodies_range: tuple[int, int] = (3, 8), + n_bodies_range: tuple[int, int] | None = None, + complexity_tier: ComplexityTier | None = None, ) -> list[dict[str, Any]]: """Generate a batch of labeled training examples. - Each example contains: - - bodies: list of body positions/orientations - - joints: list of joints with types and parameters - - labels: per-joint independence/redundancy flags - - assembly_label: overall classification + Each example contains body positions, joint descriptions, + per-joint independence labels, and assembly-level classification. + + Args: + batch_size: Number of assemblies to generate. + n_bodies_range: ``(min, max_exclusive)`` body count. + Overridden by *complexity_tier* when both are given. + complexity_tier: Predefined range (``"simple"`` / ``"medium"`` + / ``"complex"``). Overrides *n_bodies_range*. """ + if complexity_tier is not None: + n_bodies_range = COMPLEXITY_RANGES[complexity_tier] + elif n_bodies_range is None: + n_bodies_range = (3, 8) + + _joint_pool = [ + JointType.REVOLUTE, + JointType.BALL, + JointType.CYLINDRICAL, + JointType.FIXED, + ] + examples: list[dict[str, Any]] = [] for i in range(batch_size): n = int(self.rng.integers(*n_bodies_range)) - gen_idx = int(self.rng.integers(3)) + gen_idx = int(self.rng.integers(7)) if gen_idx == 0: _chain_types = [ @@ -202,11 +488,30 @@ class SyntheticAssemblyGenerator: ] jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] bodies, joints, analysis = self.generate_chain_assembly(n, jtype) + gen_name = "chain" elif gen_idx == 1: bodies, joints, analysis = self.generate_rigid_assembly(n) - else: + gen_name = "rigid" + elif gen_idx == 2: extra = int(self.rng.integers(1, 4)) bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra) + gen_name = "overconstrained" + elif gen_idx == 3: + branching = int(self.rng.integers(2, 5)) + bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching) + gen_name = "tree" + elif gen_idx == 4: + n = max(n, 3) + bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool) + gen_name = "loop" + elif gen_idx == 5: + n = max(n, 2) + bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool) + gen_name = "star" + else: + density = float(self.rng.uniform(0.2, 0.5)) + bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density) + gen_name = "mixed" # Build per-joint labels from edge results joint_labels: dict[int, dict[str, int]] = {} @@ -227,6 +532,7 @@ class SyntheticAssemblyGenerator: examples.append( { "example_id": i, + "generator_type": gen_name, "n_bodies": len(bodies), "n_joints": len(joints), "body_positions": [b.position.tolist() for b in bodies], diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py index 06c10a5..fa5d70f 100644 --- a/tests/datagen/test_generator.py +++ b/tests/datagen/test_generator.py @@ -2,11 +2,18 @@ from __future__ import annotations +from typing import ClassVar + +import numpy as np import pytest -from solver.datagen.generator import SyntheticAssemblyGenerator +from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator from solver.datagen.types import JointType +# --------------------------------------------------------------------------- +# Original generators (chain / rigid / overconstrained) +# --------------------------------------------------------------------------- + class TestChainAssembly: """generate_chain_assembly produces valid underconstrained chains.""" @@ -83,66 +90,259 @@ class TestOverconstrainedAssembly: assert len(joints_over) > len(joints_base) +# --------------------------------------------------------------------------- +# New topology generators +# --------------------------------------------------------------------------- + + +class TestTreeAssembly: + """generate_tree_assembly produces tree-structured assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_tree_assembly(6) + assert len(bodies) == 6 + assert len(joints) == 5 # n - 1 + + def test_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_tree_assembly(6) + assert analysis.combinatorial_classification == "underconstrained" + + def test_branching_factor(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2) + assert len(bodies) == 10 + assert len(joints) == 9 + + def test_mixed_joint_types(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED] + _, joints, _ = gen.generate_tree_assembly(10, joint_types=types) + used = {j.joint_type for j in joints} + # With 9 joints and 3 types, very likely to use at least 2 + assert len(used) >= 2 + + def test_single_joint_type(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL) + assert all(j.joint_type is JointType.BALL for j in joints) + + def test_sequential_body_ids(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_tree_assembly(7) + assert [b.body_id for b in bodies] == list(range(7)) + + +class TestLoopAssembly: + """generate_loop_assembly produces closed-loop assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_loop_assembly(5) + assert len(bodies) == 5 + assert len(joints) == 5 # n joints for n bodies + + def test_has_redundancy(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_loop_assembly(5) + assert analysis.combinatorial_redundant > 0 + + def test_wrap_around_connectivity(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_loop_assembly(4) + edges = {(j.body_a, j.body_b) for j in joints} + assert (0, 1) in edges + assert (1, 2) in edges + assert (2, 3) in edges + assert (3, 0) in edges # wrap-around + + def test_minimum_bodies_error(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="at least 3"): + gen.generate_loop_assembly(2) + + def test_mixed_joint_types(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + types = [JointType.REVOLUTE, JointType.FIXED] + _, joints, _ = gen.generate_loop_assembly(8, joint_types=types) + used = {j.joint_type for j in joints} + assert len(used) >= 2 + + +class TestStarAssembly: + """generate_star_assembly produces hub-and-spoke assemblies.""" + + def test_body_and_joint_counts(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_star_assembly(6) + assert len(bodies) == 6 + assert len(joints) == 5 # n - 1 + + def test_all_joints_connect_to_hub(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_star_assembly(6) + for j in joints: + assert j.body_a == 0 or j.body_b == 0 + + def test_underconstrained(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_star_assembly(5) + assert analysis.combinatorial_classification == "underconstrained" + + def test_minimum_bodies_error(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="at least 2"): + gen.generate_star_assembly(1) + + def test_hub_at_origin(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_star_assembly(4) + np.testing.assert_array_equal(bodies[0].position, np.zeros(3)) + + +class TestMixedAssembly: + """generate_mixed_assembly produces tree+loop hybrid assemblies.""" + + def test_more_joints_than_tree(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3) + assert len(joints) > len(bodies) - 1 + + def test_density_zero_is_tree(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0) + assert len(joints) == 4 # spanning tree only + + def test_density_validation(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + with pytest.raises(ValueError, match="must be in"): + gen.generate_mixed_assembly(5, edge_density=1.5) + with pytest.raises(ValueError, match="must be in"): + gen.generate_mixed_assembly(5, edge_density=-0.1) + + def test_no_duplicate_edges(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5) + edges = [frozenset([j.body_a, j.body_b]) for j in joints] + assert len(edges) == len(set(edges)) + + def test_high_density(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0) + # Fully connected: 5*(5-1)/2 = 10 edges + assert len(joints) == 10 + + +# --------------------------------------------------------------------------- +# Complexity tiers +# --------------------------------------------------------------------------- + + +class TestComplexityTiers: + """Complexity tier parameter on batch generation.""" + + def test_simple_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, complexity_tier="simple") + lo, hi = COMPLEXITY_RANGES["simple"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_medium_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, complexity_tier="medium") + lo, hi = COMPLEXITY_RANGES["medium"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_complex_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(3, complexity_tier="complex") + lo, hi = COMPLEXITY_RANGES["complex"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + def test_tier_overrides_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium") + lo, hi = COMPLEXITY_RANGES["medium"] + for ex in batch: + assert lo <= ex["n_bodies"] < hi + + +# --------------------------------------------------------------------------- +# Training batch +# --------------------------------------------------------------------------- + + class TestTrainingBatch: """generate_training_batch produces well-structured examples.""" - @pytest.fixture() - def batch(self) -> list[dict]: - gen = SyntheticAssemblyGenerator(seed=42) - return gen.generate_training_batch(batch_size=20, n_bodies_range=(3, 6)) + EXPECTED_KEYS: ClassVar[set[str]] = { + "example_id", + "generator_type", + "n_bodies", + "n_joints", + "body_positions", + "joints", + "joint_labels", + "assembly_classification", + "is_rigid", + "is_minimally_rigid", + "internal_dof", + "geometric_degeneracies", + } - def test_batch_size(self, batch: list[dict]) -> None: + VALID_GEN_TYPES: ClassVar[set[str]] = { + "chain", + "rigid", + "overconstrained", + "tree", + "loop", + "star", + "mixed", + } + + def test_batch_size(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20) assert len(batch) == 20 - def test_example_keys(self, batch: list[dict]) -> None: - expected = { - "example_id", - "n_bodies", - "n_joints", - "body_positions", - "joints", - "joint_labels", - "assembly_classification", - "is_rigid", - "is_minimally_rigid", - "internal_dof", - "geometric_degeneracies", - } + def test_example_keys(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) for ex in batch: - assert set(ex.keys()) == expected + assert set(ex.keys()) == self.EXPECTED_KEYS - def test_example_ids_sequential(self, batch: list[dict]) -> None: - ids = [ex["example_id"] for ex in batch] - assert ids == list(range(20)) + def test_example_ids_sequential(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(15) + assert [ex["example_id"] for ex in batch] == list(range(15)) - def test_classification_distribution(self, batch: list[dict]) -> None: - """Batch should contain multiple classification types.""" - classes = {ex["assembly_classification"] for ex in batch} - # With the 3-way generator split we expect at least 2 types - assert len(classes) >= 2 - - def test_body_count_in_range(self, batch: list[dict]) -> None: + def test_generator_type_valid(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(50) for ex in batch: - assert 3 <= ex["n_bodies"] <= 5 # range is [3, 6) + assert ex["generator_type"] in self.VALID_GEN_TYPES - def test_joint_labels_match_joints(self, batch: list[dict]) -> None: + def test_generator_type_diversity(self) -> None: + """100-sample batch should use at least 5 of 7 generator types.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100) + types = {ex["generator_type"] for ex in batch} + assert len(types) >= 5 + + def test_default_body_range(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) for ex in batch: - label_jids = set(ex["joint_labels"].keys()) - joint_jids = {j["joint_id"] for j in ex["joints"]} - assert label_jids == joint_jids + assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp - def test_joint_label_fields(self, batch: list[dict]) -> None: - expected_fields = { - "independent_constraints", - "redundant_constraints", - "total_constraints", - } - for ex in batch: - for label in ex["joint_labels"].values(): - assert set(label.keys()) == expected_fields - - def test_joint_label_consistency(self, batch: list[dict]) -> None: + def test_joint_label_consistency(self) -> None: """independent + redundant == total for every joint.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(30) for ex in batch: for label in ex["joint_labels"].values(): total = label["independent_constraints"] + label["redundant_constraints"] @@ -157,10 +357,17 @@ class TestSeedReproducibility: g2 = SyntheticAssemblyGenerator(seed=2) b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) - # Very unlikely to be identical with different seeds c1 = [ex["assembly_classification"] for ex in b1] c2 = [ex["assembly_classification"] for ex in b2] r1 = [ex["is_rigid"] for ex in b1] r2 = [ex["is_rigid"] for ex in b2] - # At least one of these should differ (probabilistically certain) assert c1 != c2 or r1 != r2 + + def test_same_seed_identical(self) -> None: + g1 = SyntheticAssemblyGenerator(seed=123) + g2 = SyntheticAssemblyGenerator(seed=123) + b1, j1, _ = g1.generate_tree_assembly(5) + b2, j2, _ = g2.generate_tree_assembly(5) + for a, b in zip(b1, b2, strict=True): + np.testing.assert_array_almost_equal(a.position, b.position) + assert len(j1) == len(j2) From 78289494e274e3796a4bf9e2d5ef97feb3e14ec4 Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 14:57:49 -0600 Subject: [PATCH 10/12] feat: geometric diversity for synthetic assembly generation - Add AxisStrategy type (cardinal, random, near_parallel) - Add random body orientations via scipy.spatial.transform.Rotation - Add parallel axis injection with configurable probability - Add grounded parameter on all 7 generators (grounded/floating) - Add axis sampling strategies: cardinal, random, near-parallel - Update _create_joint with orientation-aware anchor offsets - Add _resolve_axis helper for parallel axis propagation - Update generate_training_batch with axis_strategy, parallel_axis_prob, grounded_ratio parameters - Add body_orientations and grounded fields to batch output - Export AxisStrategy from datagen package - Add 28 new tests (72 total generator tests, 158 total) Closes #8 --- solver/datagen/__init__.py | 7 +- solver/datagen/generator.py | 462 +++++++++++++++++++++++++++----- tests/datagen/test_generator.py | 321 +++++++++++++++++++++- 3 files changed, 710 insertions(+), 80 deletions(-) diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 203adcc..a1d0c34 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,7 +1,11 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly -from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator +from solver.datagen.generator import ( + COMPLEXITY_RANGES, + AxisStrategy, + SyntheticAssemblyGenerator, +) from solver.datagen.jacobian import JacobianVerifier from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( @@ -14,6 +18,7 @@ from solver.datagen.types import ( __all__ = [ "COMPLEXITY_RANGES", + "AxisStrategy", "ConstraintAnalysis", "JacobianVerifier", "Joint", diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py index cd90ade..15df68f 100644 --- a/solver/datagen/generator.py +++ b/solver/datagen/generator.py @@ -10,6 +10,7 @@ from __future__ import annotations from typing import TYPE_CHECKING, Literal import numpy as np +from scipy.spatial.transform import Rotation from solver.datagen.analysis import analyze_assembly from solver.datagen.types import ( @@ -22,7 +23,12 @@ from solver.datagen.types import ( if TYPE_CHECKING: from typing import Any -__all__ = ["COMPLEXITY_RANGES", "ComplexityTier", "SyntheticAssemblyGenerator"] +__all__ = [ + "COMPLEXITY_RANGES", + "AxisStrategy", + "ComplexityTier", + "SyntheticAssemblyGenerator", +] # --------------------------------------------------------------------------- # Complexity tiers — ranges use exclusive upper bound for rng.integers() @@ -36,6 +42,12 @@ COMPLEXITY_RANGES: dict[str, tuple[int, int]] = { "complex": (16, 51), } +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + +AxisStrategy = Literal["cardinal", "random", "near_parallel"] + class SyntheticAssemblyGenerator: """Generates assembly graphs with known minimal constraint sets. @@ -67,6 +79,63 @@ class SyntheticAssemblyGenerator: axis /= np.linalg.norm(axis) return axis + def _random_orientation(self) -> np.ndarray: + """Generate a random 3x3 rotation matrix.""" + mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix() + return mat + + def _cardinal_axis(self) -> np.ndarray: + """Pick uniformly from the six signed cardinal directions.""" + axes = np.array( + [ + [1, 0, 0], + [-1, 0, 0], + [0, 1, 0], + [0, -1, 0], + [0, 0, 1], + [0, 0, -1], + ], + dtype=float, + ) + result: np.ndarray = axes[int(self.rng.integers(6))] + return result + + def _near_parallel_axis( + self, + base_axis: np.ndarray, + perturbation_scale: float = 0.05, + ) -> np.ndarray: + """Return *base_axis* with a small random perturbation, re-normalized.""" + perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale + return perturbed / np.linalg.norm(perturbed) + + def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray: + """Sample a joint axis using the specified strategy.""" + if strategy == "cardinal": + return self._cardinal_axis() + if strategy == "near_parallel": + return self._near_parallel_axis(np.array([0.0, 0.0, 1.0])) + return self._random_axis() + + def _resolve_axis( + self, + strategy: AxisStrategy, + parallel_axis_prob: float, + shared_axis: np.ndarray | None, + ) -> tuple[np.ndarray, np.ndarray | None]: + """Return (axis_for_this_joint, shared_axis_to_propagate). + + On the first call where *shared_axis* is ``None`` and parallel + injection triggers, a base axis is chosen and returned as + *shared_axis* for subsequent calls. + """ + if shared_axis is not None: + return self._near_parallel_axis(shared_axis), shared_axis + if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob: + base = self._sample_axis(strategy) + return base.copy(), base + return self._sample_axis(strategy), None + def _select_joint_type( self, joint_types: JointType | list[JointType], @@ -85,17 +154,37 @@ class SyntheticAssemblyGenerator: pos_a: np.ndarray, pos_b: np.ndarray, joint_type: JointType, + *, + axis: np.ndarray | None = None, + orient_a: np.ndarray | None = None, + orient_b: np.ndarray | None = None, ) -> Joint: - """Create a joint between two bodies with random axis at midpoint.""" - anchor = (pos_a + pos_b) / 2.0 + """Create a joint between two bodies. + + When orientations are provided, anchor points are offset from + each body's center along a random local direction rotated into + world frame, rather than placed at the midpoint. + """ + if orient_a is not None and orient_b is not None: + dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1) + offset_scale = dist * 0.2 + local_a = self.rng.standard_normal(3) * offset_scale + local_b = self.rng.standard_normal(3) * offset_scale + anchor_a = pos_a + orient_a @ local_a + anchor_b = pos_b + orient_b @ local_b + else: + anchor = (pos_a + pos_b) / 2.0 + anchor_a = anchor + anchor_b = anchor + return Joint( joint_id=joint_id, body_a=body_a_id, body_b=body_b_id, joint_type=joint_type, - anchor_a=anchor, - anchor_b=anchor, - axis=self._random_axis(), + anchor_a=anchor_a, + anchor_b=anchor_b, + axis=axis if axis is not None else self._random_axis(), ) # ------------------------------------------------------------------ @@ -106,6 +195,10 @@ class SyntheticAssemblyGenerator: self, n_bodies: int, joint_type: JointType = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a serial kinematic chain. @@ -118,27 +211,49 @@ class SyntheticAssemblyGenerator: for i in range(n_bodies): pos = np.array([i * 2.0, 0.0, 0.0]) - bodies.append(RigidBody(body_id=i, position=pos)) - - for i in range(n_bodies - 1): - anchor = np.array([(i + 0.5) * 2.0, 0.0, 0.0]) - joints.append( - Joint( - joint_id=i, - body_a=i, - body_b=i + 1, - joint_type=joint_type, - anchor_a=anchor, - anchor_b=anchor, - axis=self._random_axis(), + bodies.append( + RigidBody( + body_id=i, + position=pos, + orientation=self._random_orientation(), ) ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + shared_axis: np.ndarray | None = None + for i in range(n_bodies - 1): + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i, + i, + i + 1, + bodies[i].position, + bodies[i + 1].position, + joint_type, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[i + 1].orientation, + ) + ) + + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_rigid_assembly( - self, n_bodies: int + self, + n_bodies: int, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a minimally rigid assembly by adding joints until rigid. @@ -148,22 +263,35 @@ class SyntheticAssemblyGenerator: """ bodies = [] for i in range(n_bodies): - bodies.append(RigidBody(body_id=i, position=self._random_position())) + bodies.append( + RigidBody( + body_id=i, + position=self._random_position(), + orientation=self._random_orientation(), + ) + ) # Build spanning tree with fixed joints (overconstrained) joints: list[Joint] = [] + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): parent = int(self.rng.integers(0, i)) - mid = (bodies[i].position + bodies[parent].position) / 2 + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - Joint( - joint_id=i - 1, - body_a=parent, - body_b=i, - joint_type=JointType.FIXED, - anchor_a=mid, - anchor_b=mid, - axis=self._random_axis(), + self._create_joint( + i - 1, + parent, + i, + bodies[parent].position, + bodies[i].position, + JointType.FIXED, + axis=axis, + orient_a=bodies[parent].orientation, + orient_b=bodies[i].orientation, ) ) @@ -174,56 +302,73 @@ class SyntheticAssemblyGenerator: JointType.BALL, ] + ground = 0 if grounded else None for idx in self.rng.permutation(len(joints)): original_type = joints[idx].joint_type for candidate in weaker_types: joints[idx].joint_type = candidate - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly(bodies, joints, ground_body=ground) if analysis.is_rigid: break # Keep the weaker type else: joints[idx].joint_type = original_type - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly(bodies, joints, ground_body=ground) return bodies, joints, analysis def generate_overconstrained_assembly( self, n_bodies: int, extra_joints: int = 2, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate an assembly with known redundant constraints. Starts with a rigid assembly, then adds extra joints that the pebble game will flag as redundant. """ - bodies, joints, _ = self.generate_rigid_assembly(n_bodies) + bodies, joints, _ = self.generate_rigid_assembly( + n_bodies, + grounded=grounded, + axis_strategy=axis_strategy, + parallel_axis_prob=parallel_axis_prob, + ) joint_id = len(joints) + shared_axis: np.ndarray | None = None for _ in range(extra_joints): a, b = self.rng.choice(n_bodies, size=2, replace=False) - mid = (bodies[a].position + bodies[b].position) / 2 - _overcon_types = [ JointType.REVOLUTE, JointType.FIXED, JointType.BALL, ] jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))] + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - Joint( - joint_id=joint_id, - body_a=int(a), - body_b=int(b), - joint_type=jtype, - anchor_a=mid, - anchor_b=mid, - axis=self._random_axis(), + self._create_joint( + joint_id, + int(a), + int(b), + bodies[int(a)].position, + bodies[int(b)].position, + jtype, + axis=axis, + orient_a=bodies[int(a)].orientation, + orient_b=bodies[int(b)].orientation, ) ) joint_id += 1 - analysis = analyze_assembly(bodies, joints, ground_body=0) + ground = 0 if grounded else None + analysis = analyze_assembly(bodies, joints, ground_body=ground) return bodies, joints, analysis # ------------------------------------------------------------------ @@ -235,6 +380,10 @@ class SyntheticAssemblyGenerator: n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, branching_factor: int = 3, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a random tree topology with configurable branching. @@ -247,12 +396,19 @@ class SyntheticAssemblyGenerator: joint_types: Single type or list to sample from per joint. branching_factor: Max children per parent (1-5 recommended). """ - bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + bodies: list[RigidBody] = [ + RigidBody( + body_id=0, + position=np.zeros(3), + orientation=self._random_orientation(), + ) + ] joints: list[Joint] = [] available_parents = [0] next_id = 1 joint_id = 0 + shared_axis: np.ndarray | None = None while next_id < n_bodies and available_parents: pidx = int(self.rng.integers(0, len(available_parents))) @@ -266,11 +422,33 @@ class SyntheticAssemblyGenerator: direction = self._random_axis() distance = self.rng.uniform(1.5, 3.0) child_pos = parent_pos + direction * distance + child_orient = self._random_orientation() - bodies.append(RigidBody(body_id=next_id, position=child_pos)) + bodies.append( + RigidBody( + body_id=next_id, + position=child_pos, + orientation=child_orient, + ) + ) jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - self._create_joint(joint_id, parent_id, next_id, parent_pos, child_pos, jtype) + self._create_joint( + joint_id, + parent_id, + next_id, + parent_pos, + child_pos, + jtype, + axis=axis, + orient_a=bodies[parent_id].orientation, + orient_b=child_orient, + ) ) available_parents.append(next_id) @@ -283,13 +461,21 @@ class SyntheticAssemblyGenerator: if n_children >= branching_factor or self.rng.random() < 0.3: available_parents.pop(pidx) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_loop_assembly( self, n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a single closed loop (ring) of bodies. @@ -317,22 +503,52 @@ class SyntheticAssemblyGenerator: x = radius * np.cos(angle) y = radius * np.sin(angle) z = float(self.rng.uniform(-0.2, 0.2)) - bodies.append(RigidBody(body_id=i, position=np.array([x, y, z]))) + bodies.append( + RigidBody( + body_id=i, + position=np.array([x, y, z]), + orientation=self._random_orientation(), + ) + ) + shared_axis: np.ndarray | None = None for i in range(n_bodies): next_i = (i + 1) % n_bodies jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( - self._create_joint(i, i, next_i, bodies[i].position, bodies[next_i].position, jtype) + self._create_joint( + i, + i, + next_i, + bodies[i].position, + bodies[next_i].position, + jtype, + axis=axis, + orient_a=bodies[i].orientation, + orient_b=bodies[next_i].orientation, + ) ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_star_assembly( self, n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a star topology with central hub and satellites. @@ -350,19 +566,49 @@ class SyntheticAssemblyGenerator: msg = "Star assembly requires at least 2 bodies" raise ValueError(msg) - bodies: list[RigidBody] = [RigidBody(body_id=0, position=np.zeros(3))] + hub_orient = self._random_orientation() + bodies: list[RigidBody] = [ + RigidBody( + body_id=0, + position=np.zeros(3), + orientation=hub_orient, + ) + ] joints: list[Joint] = [] + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): direction = self._random_axis() distance = self.rng.uniform(2.0, 5.0) pos = direction * distance - bodies.append(RigidBody(body_id=i, position=pos)) + sat_orient = self._random_orientation() + bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient)) jtype = self._select_joint_type(joint_types) - joints.append(self._create_joint(i - 1, 0, i, np.zeros(3), pos, jtype)) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) + joints.append( + self._create_joint( + i - 1, + 0, + i, + np.zeros(3), + pos, + jtype, + axis=axis, + orient_a=hub_orient, + orient_b=sat_orient, + ) + ) - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis def generate_mixed_assembly( @@ -370,6 +616,10 @@ class SyntheticAssemblyGenerator: n_bodies: int, joint_types: JointType | list[JointType] = JointType.REVOLUTE, edge_density: float = 0.3, + *, + grounded: bool = True, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]: """Generate a mixed topology combining tree and loop elements. @@ -392,14 +642,26 @@ class SyntheticAssemblyGenerator: joints: list[Joint] = [] for i in range(n_bodies): - bodies.append(RigidBody(body_id=i, position=self._random_position())) + bodies.append( + RigidBody( + body_id=i, + position=self._random_position(), + orientation=self._random_orientation(), + ) + ) # Phase 1: spanning tree joint_id = 0 existing_edges: set[frozenset[int]] = set() + shared_axis: np.ndarray | None = None for i in range(1, n_bodies): parent = int(self.rng.integers(0, i)) jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( self._create_joint( joint_id, @@ -408,6 +670,9 @@ class SyntheticAssemblyGenerator: bodies[parent].position, bodies[i].position, jtype, + axis=axis, + orient_a=bodies[parent].orientation, + orient_b=bodies[i].orientation, ) ) existing_edges.add(frozenset([parent, i])) @@ -425,6 +690,11 @@ class SyntheticAssemblyGenerator: for a, b in candidates[:n_extra]: jtype = self._select_joint_type(joint_types) + axis, shared_axis = self._resolve_axis( + axis_strategy, + parallel_axis_prob, + shared_axis, + ) joints.append( self._create_joint( joint_id, @@ -433,11 +703,18 @@ class SyntheticAssemblyGenerator: bodies[a].position, bodies[b].position, jtype, + axis=axis, + orient_a=bodies[a].orientation, + orient_b=bodies[b].orientation, ) ) joint_id += 1 - analysis = analyze_assembly(bodies, joints, ground_body=0) + analysis = analyze_assembly( + bodies, + joints, + ground_body=0 if grounded else None, + ) return bodies, joints, analysis # ------------------------------------------------------------------ @@ -449,6 +726,10 @@ class SyntheticAssemblyGenerator: batch_size: int = 100, n_bodies_range: tuple[int, int] | None = None, complexity_tier: ComplexityTier | None = None, + *, + axis_strategy: AxisStrategy = "random", + parallel_axis_prob: float = 0.0, + grounded_ratio: float = 1.0, ) -> list[dict[str, Any]]: """Generate a batch of labeled training examples. @@ -461,6 +742,9 @@ class SyntheticAssemblyGenerator: Overridden by *complexity_tier* when both are given. complexity_tier: Predefined range (``"simple"`` / ``"medium"`` / ``"complex"``). Overrides *n_bodies_range*. + axis_strategy: Axis sampling strategy for joint axes. + parallel_axis_prob: Probability of parallel axis injection. + grounded_ratio: Fraction of examples that are grounded. """ if complexity_tier is not None: n_bodies_range = COMPLEXITY_RANGES[complexity_tier] @@ -474,11 +758,17 @@ class SyntheticAssemblyGenerator: JointType.FIXED, ] + geo_kw: dict[str, Any] = { + "axis_strategy": axis_strategy, + "parallel_axis_prob": parallel_axis_prob, + } + examples: list[dict[str, Any]] = [] for i in range(batch_size): n = int(self.rng.integers(*n_bodies_range)) gen_idx = int(self.rng.integers(7)) + grounded = bool(self.rng.random() < grounded_ratio) if gen_idx == 0: _chain_types = [ @@ -487,30 +777,66 @@ class SyntheticAssemblyGenerator: JointType.CYLINDRICAL, ] jtype = _chain_types[int(self.rng.integers(len(_chain_types)))] - bodies, joints, analysis = self.generate_chain_assembly(n, jtype) + bodies, joints, analysis = self.generate_chain_assembly( + n, + jtype, + grounded=grounded, + **geo_kw, + ) gen_name = "chain" elif gen_idx == 1: - bodies, joints, analysis = self.generate_rigid_assembly(n) + bodies, joints, analysis = self.generate_rigid_assembly( + n, + grounded=grounded, + **geo_kw, + ) gen_name = "rigid" elif gen_idx == 2: extra = int(self.rng.integers(1, 4)) - bodies, joints, analysis = self.generate_overconstrained_assembly(n, extra) + bodies, joints, analysis = self.generate_overconstrained_assembly( + n, + extra, + grounded=grounded, + **geo_kw, + ) gen_name = "overconstrained" elif gen_idx == 3: branching = int(self.rng.integers(2, 5)) - bodies, joints, analysis = self.generate_tree_assembly(n, _joint_pool, branching) + bodies, joints, analysis = self.generate_tree_assembly( + n, + _joint_pool, + branching, + grounded=grounded, + **geo_kw, + ) gen_name = "tree" elif gen_idx == 4: n = max(n, 3) - bodies, joints, analysis = self.generate_loop_assembly(n, _joint_pool) + bodies, joints, analysis = self.generate_loop_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) gen_name = "loop" elif gen_idx == 5: n = max(n, 2) - bodies, joints, analysis = self.generate_star_assembly(n, _joint_pool) + bodies, joints, analysis = self.generate_star_assembly( + n, + _joint_pool, + grounded=grounded, + **geo_kw, + ) gen_name = "star" else: density = float(self.rng.uniform(0.2, 0.5)) - bodies, joints, analysis = self.generate_mixed_assembly(n, _joint_pool, density) + bodies, joints, analysis = self.generate_mixed_assembly( + n, + _joint_pool, + density, + grounded=grounded, + **geo_kw, + ) gen_name = "mixed" # Build per-joint labels from edge results @@ -533,9 +859,11 @@ class SyntheticAssemblyGenerator: { "example_id": i, "generator_type": gen_name, + "grounded": grounded, "n_bodies": len(bodies), "n_joints": len(joints), "body_positions": [b.position.tolist() for b in bodies], + "body_orientations": [b.orientation.tolist() for b in bodies], "joints": [ { "joint_id": j.joint_id, @@ -547,11 +875,11 @@ class SyntheticAssemblyGenerator: for j in joints ], "joint_labels": joint_labels, - "assembly_classification": analysis.combinatorial_classification, + "assembly_classification": (analysis.combinatorial_classification), "is_rigid": analysis.is_rigid, "is_minimally_rigid": analysis.is_minimally_rigid, "internal_dof": analysis.jacobian_internal_dof, - "geometric_degeneracies": analysis.geometric_degeneracies, + "geometric_degeneracies": (analysis.geometric_degeneracies), } ) diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py index fa5d70f..dc70baa 100644 --- a/tests/datagen/test_generator.py +++ b/tests/datagen/test_generator.py @@ -44,7 +44,10 @@ class TestChainAssembly: def test_chain_custom_joint_type(self) -> None: gen = SyntheticAssemblyGenerator(seed=0) - _, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL) + _, joints, _ = gen.generate_chain_assembly( + 3, + joint_type=JointType.BALL, + ) assert all(j.joint_type is JointType.BALL for j in joints) @@ -77,7 +80,10 @@ class TestOverconstrainedAssembly: def test_has_redundant(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2) + _, _, analysis = gen.generate_overconstrained_assembly( + 4, + extra_joints=2, + ) assert analysis.combinatorial_redundant > 0 def test_extra_joints_added(self) -> None: @@ -85,7 +91,10 @@ class TestOverconstrainedAssembly: _, joints_base, _ = gen.generate_rigid_assembly(4) gen2 = SyntheticAssemblyGenerator(seed=42) - _, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3) + _, joints_over, _ = gen2.generate_overconstrained_assembly( + 4, + extra_joints=3, + ) # Overconstrained has base joints + extra assert len(joints_over) > len(joints_base) @@ -111,7 +120,10 @@ class TestTreeAssembly: def test_branching_factor(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2) + bodies, joints, _ = gen.generate_tree_assembly( + 10, + branching_factor=2, + ) assert len(bodies) == 10 assert len(joints) == 9 @@ -125,7 +137,10 @@ class TestTreeAssembly: def test_single_joint_type(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL) + _, joints, _ = gen.generate_tree_assembly( + 5, + joint_types=JointType.BALL, + ) assert all(j.joint_type is JointType.BALL for j in joints) def test_sequential_body_ids(self) -> None: @@ -206,12 +221,18 @@ class TestMixedAssembly: def test_more_joints_than_tree(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3) + bodies, joints, _ = gen.generate_mixed_assembly( + 8, + edge_density=0.3, + ) assert len(joints) > len(bodies) - 1 def test_density_zero_is_tree(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=0.0, + ) assert len(joints) == 4 # spanning tree only def test_density_validation(self) -> None: @@ -229,11 +250,210 @@ class TestMixedAssembly: def test_high_density(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - _bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0) + _bodies, joints, _ = gen.generate_mixed_assembly( + 5, + edge_density=1.0, + ) # Fully connected: 5*(5-1)/2 = 10 edges assert len(joints) == 10 +# --------------------------------------------------------------------------- +# Axis sampling strategies +# --------------------------------------------------------------------------- + + +class TestAxisStrategy: + """Axis sampling strategies produce valid unit vectors.""" + + def test_cardinal_axis_from_six(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axes = {tuple(gen._cardinal_axis()) for _ in range(200)} + expected = { + (1, 0, 0), + (-1, 0, 0), + (0, 1, 0), + (0, -1, 0), + (0, 0, 1), + (0, 0, -1), + } + assert axes == expected + + def test_random_axis_unit_norm(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + for _ in range(50): + axis = gen._sample_axis("random") + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + + def test_near_parallel_close_to_base(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + base = np.array([0.0, 0.0, 1.0]) + for _ in range(50): + axis = gen._near_parallel_axis(base) + assert abs(np.linalg.norm(axis) - 1.0) < 1e-10 + assert np.dot(axis, base) > 0.95 + + def test_sample_axis_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("cardinal") + cardinals = [ + np.array(v, dtype=float) + for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)] + ] + assert any(np.allclose(axis, c) for c in cardinals) + + def test_sample_axis_near_parallel(self) -> None: + gen = SyntheticAssemblyGenerator(seed=0) + axis = gen._sample_axis("near_parallel") + z = np.array([0.0, 0.0, 1.0]) + assert np.dot(axis, z) > 0.95 + + +# --------------------------------------------------------------------------- +# Geometric diversity: orientations +# --------------------------------------------------------------------------- + + +class TestRandomOrientations: + """Bodies should have non-identity orientations.""" + + def test_bodies_have_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_tree_assembly(5) + non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3))) + assert non_identity >= 3 + + def test_orientations_are_valid_rotations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + bodies, _, _ = gen.generate_star_assembly(6) + for b in bodies: + r = b.orientation + # R^T R == I + np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10) + # det(R) == 1 + assert abs(np.linalg.det(r) - 1.0) < 1e-10 + + def test_all_generators_set_orientations(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + # Chain + bodies, _, _ = gen.generate_chain_assembly(3) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Loop + bodies, _, _ = gen.generate_loop_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + # Mixed + bodies, _, _ = gen.generate_mixed_assembly(4) + assert not np.allclose(bodies[1].orientation, np.eye(3)) + + +# --------------------------------------------------------------------------- +# Geometric diversity: grounded parameter +# --------------------------------------------------------------------------- + + +class TestGroundedParameter: + """Grounded parameter controls ground_body in analysis.""" + + def test_chain_grounded_default(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly(4) + assert analysis.combinatorial_dof >= 0 + + def test_chain_floating(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, _, analysis = gen.generate_chain_assembly( + 4, + grounded=False, + ) + # Floating: 6 trivial DOF not subtracted by ground + assert analysis.combinatorial_dof >= 6 + + def test_floating_vs_grounded_dof_difference(self) -> None: + gen1 = SyntheticAssemblyGenerator(seed=42) + _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True) + gen2 = SyntheticAssemblyGenerator(seed=42) + _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False) + # Floating should have higher DOF due to missing ground constraint + assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof + + @pytest.mark.parametrize( + "gen_method", + [ + "generate_chain_assembly", + "generate_rigid_assembly", + "generate_tree_assembly", + "generate_loop_assembly", + "generate_star_assembly", + "generate_mixed_assembly", + ], + ) + def test_all_generators_accept_grounded( + self, + gen_method: str, + ) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + method = getattr(gen, gen_method) + n = 4 + # Should not raise + if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"): + method(n, grounded=False) + else: + method(n, grounded=False) + + +# --------------------------------------------------------------------------- +# Geometric diversity: parallel axis injection +# --------------------------------------------------------------------------- + + +class TestParallelAxisInjection: + """parallel_axis_prob causes shared axis direction.""" + + def test_parallel_axes_similar(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + # Near-parallel: |dot| close to 1 + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_zero_prob_no_forced_parallel(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_chain_assembly( + 6, + parallel_axis_prob=0.0, + ) + base = joints[0].axis + dots = [abs(np.dot(j.axis, base)) for j in joints[1:]] + # With 5 random axes, extremely unlikely all are parallel + assert min(dots) < 0.95 + + def test_parallel_on_loop(self) -> None: + """Parallel axes on a loop assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_loop_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + def test_parallel_on_star(self) -> None: + """Parallel axes on a star assembly.""" + gen = SyntheticAssemblyGenerator(seed=42) + _, joints, _ = gen.generate_star_assembly( + 5, + parallel_axis_prob=1.0, + ) + base = joints[0].axis + for j in joints[1:]: + assert abs(np.dot(j.axis, base)) > 0.9 + + # --------------------------------------------------------------------------- # Complexity tiers # --------------------------------------------------------------------------- @@ -265,7 +485,11 @@ class TestComplexityTiers: def test_tier_overrides_range(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) - batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium") + batch = gen.generate_training_batch( + 10, + n_bodies_range=(2, 3), + complexity_tier="medium", + ) lo, hi = COMPLEXITY_RANGES["medium"] for ex in batch: assert lo <= ex["n_bodies"] < hi @@ -282,9 +506,11 @@ class TestTrainingBatch: EXPECTED_KEYS: ClassVar[set[str]] = { "example_id", "generator_type", + "grounded", "n_bodies", "n_joints", "body_positions", + "body_orientations", "joints", "joint_labels", "assembly_classification", @@ -337,7 +563,8 @@ class TestTrainingBatch: gen = SyntheticAssemblyGenerator(seed=42) batch = gen.generate_training_batch(30) for ex in batch: - assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp + # default (3, 8), but loop/star may clamp + assert 2 <= ex["n_bodies"] <= 7 def test_joint_label_consistency(self) -> None: """independent + redundant == total for every joint.""" @@ -348,6 +575,70 @@ class TestTrainingBatch: total = label["independent_constraints"] + label["redundant_constraints"] assert total == label["total_constraints"] + def test_body_orientations_present(self) -> None: + """Each example includes body_orientations as 3x3 lists.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + orients = ex["body_orientations"] + assert len(orients) == ex["n_bodies"] + for o in orients: + assert len(o) == 3 + assert len(o[0]) == 3 + + def test_grounded_field_present(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + assert isinstance(ex["grounded"], bool) + + +# --------------------------------------------------------------------------- +# Batch grounded ratio +# --------------------------------------------------------------------------- + + +class TestBatchGroundedRatio: + """grounded_ratio controls the mix in batch generation.""" + + def test_all_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=1.0) + assert all(ex["grounded"] for ex in batch) + + def test_none_grounded(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(20, grounded_ratio=0.0) + assert not any(ex["grounded"] for ex in batch) + + def test_mixed_ratio(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(100, grounded_ratio=0.5) + grounded_count = sum(1 for ex in batch if ex["grounded"]) + # With 100 samples and p=0.5, should be roughly 50 +/- 20 + assert 20 < grounded_count < 80 + + def test_batch_axis_strategy_cardinal(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + axis_strategy="cardinal", + ) + assert len(batch) == 10 + + def test_batch_parallel_axis_prob(self) -> None: + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch( + 10, + parallel_axis_prob=0.5, + ) + assert len(batch) == 10 + + +# --------------------------------------------------------------------------- +# Seed reproducibility +# --------------------------------------------------------------------------- + class TestSeedReproducibility: """Different seeds produce different results.""" @@ -355,8 +646,14 @@ class TestSeedReproducibility: def test_different_seeds_differ(self) -> None: g1 = SyntheticAssemblyGenerator(seed=1) g2 = SyntheticAssemblyGenerator(seed=2) - b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) - b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6)) + b1 = g1.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) + b2 = g2.generate_training_batch( + batch_size=5, + n_bodies_range=(3, 6), + ) c1 = [ex["assembly_classification"] for ex in b1] c2 = [ex["assembly_classification"] for ex in b2] r1 = [ex["is_rigid"] for ex in b1] From 8a49f8ef40eaf232785010764685be159b4053cb Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Mon, 2 Feb 2026 15:20:02 -0600 Subject: [PATCH 11/12] feat: ground truth labeling pipeline - Create solver/datagen/labeling.py with label_assembly() function - Add dataclasses: ConstraintLabel, JointLabel, BodyDofLabel, AssemblyLabel, AssemblyLabels - Per-constraint labels: pebble_independent + jacobian_independent - Per-joint labels: aggregated independent/redundant/total counts - Per-body DOF: translational + rotational from nullspace projection - Assembly label: classification, total_dof, has_degeneracy flag - AssemblyLabels.to_dict() for JSON-serializable output - Integrate into generate_training_batch (adds 'labels' field) - Export AssemblyLabels and label_assembly from datagen package - Add 25 labeling tests + 1 batch structure test (184 total) Closes #9 --- solver/datagen/__init__.py | 3 + solver/datagen/generator.py | 7 + solver/datagen/labeling.py | 394 ++++++++++++++++++++++++++++++++ tests/datagen/test_generator.py | 12 + tests/datagen/test_labeling.py | 346 ++++++++++++++++++++++++++++ 5 files changed, 762 insertions(+) create mode 100644 solver/datagen/labeling.py create mode 100644 tests/datagen/test_labeling.py diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index a1d0c34..169907e 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -7,6 +7,7 @@ from solver.datagen.generator import ( SyntheticAssemblyGenerator, ) from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.labeling import AssemblyLabels, label_assembly from solver.datagen.pebble_game import PebbleGame3D from solver.datagen.types import ( ConstraintAnalysis, @@ -18,6 +19,7 @@ from solver.datagen.types import ( __all__ = [ "COMPLEXITY_RANGES", + "AssemblyLabels", "AxisStrategy", "ConstraintAnalysis", "JacobianVerifier", @@ -28,4 +30,5 @@ __all__ = [ "RigidBody", "SyntheticAssemblyGenerator", "analyze_assembly", + "label_assembly", ] diff --git a/solver/datagen/generator.py b/solver/datagen/generator.py index 15df68f..ad753f9 100644 --- a/solver/datagen/generator.py +++ b/solver/datagen/generator.py @@ -13,6 +13,7 @@ import numpy as np from scipy.spatial.transform import Rotation from solver.datagen.analysis import analyze_assembly +from solver.datagen.labeling import label_assembly from solver.datagen.types import ( ConstraintAnalysis, Joint, @@ -839,6 +840,11 @@ class SyntheticAssemblyGenerator: ) gen_name = "mixed" + # Produce ground truth labels (includes ConstraintAnalysis) + ground = 0 if grounded else None + labels = label_assembly(bodies, joints, ground_body=ground) + analysis = labels.analysis + # Build per-joint labels from edge results joint_labels: dict[int, dict[str, int]] = {} for result in analysis.per_edge_results: @@ -875,6 +881,7 @@ class SyntheticAssemblyGenerator: for j in joints ], "joint_labels": joint_labels, + "labels": labels.to_dict(), "assembly_classification": (analysis.combinatorial_classification), "is_rigid": analysis.is_rigid, "is_minimally_rigid": analysis.is_minimally_rigid, diff --git a/solver/datagen/labeling.py b/solver/datagen/labeling.py new file mode 100644 index 0000000..e03de42 --- /dev/null +++ b/solver/datagen/labeling.py @@ -0,0 +1,394 @@ +"""Ground truth labeling pipeline for synthetic assemblies. + +Produces rich per-constraint, per-joint, per-body, and assembly-level +labels by running both the pebble game and Jacobian verification and +correlating their results. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import numpy as np + +from solver.datagen.jacobian import JacobianVerifier +from solver.datagen.pebble_game import PebbleGame3D +from solver.datagen.types import ( + ConstraintAnalysis, + Joint, + JointType, + RigidBody, +) + +if TYPE_CHECKING: + from typing import Any + +__all__ = ["AssemblyLabels", "label_assembly"] + +_GROUND_ID = -1 +_SVD_TOL = 1e-8 + + +# --------------------------------------------------------------------------- +# Label dataclasses +# --------------------------------------------------------------------------- + + +@dataclass +class ConstraintLabel: + """Per scalar-constraint label combining both analysis methods.""" + + joint_id: int + constraint_idx: int + pebble_independent: bool + jacobian_independent: bool + + +@dataclass +class JointLabel: + """Aggregated constraint counts for a single joint.""" + + joint_id: int + independent_count: int + redundant_count: int + total: int + + +@dataclass +class BodyDofLabel: + """Per-body DOF signature from nullspace projection.""" + + body_id: int + translational_dof: int + rotational_dof: int + + +@dataclass +class AssemblyLabel: + """Assembly-wide summary label.""" + + classification: str + total_dof: int + redundant_count: int + is_rigid: bool + is_minimally_rigid: bool + has_degeneracy: bool + + +@dataclass +class AssemblyLabels: + """Complete ground truth labels for an assembly.""" + + per_constraint: list[ConstraintLabel] + per_joint: list[JointLabel] + per_body: list[BodyDofLabel] + assembly: AssemblyLabel + analysis: ConstraintAnalysis + + def to_dict(self) -> dict[str, Any]: + """Return a JSON-serializable dict.""" + return { + "per_constraint": [ + { + "joint_id": c.joint_id, + "constraint_idx": c.constraint_idx, + "pebble_independent": c.pebble_independent, + "jacobian_independent": c.jacobian_independent, + } + for c in self.per_constraint + ], + "per_joint": [ + { + "joint_id": j.joint_id, + "independent_count": j.independent_count, + "redundant_count": j.redundant_count, + "total": j.total, + } + for j in self.per_joint + ], + "per_body": [ + { + "body_id": b.body_id, + "translational_dof": b.translational_dof, + "rotational_dof": b.rotational_dof, + } + for b in self.per_body + ], + "assembly": { + "classification": self.assembly.classification, + "total_dof": self.assembly.total_dof, + "redundant_count": self.assembly.redundant_count, + "is_rigid": self.assembly.is_rigid, + "is_minimally_rigid": self.assembly.is_minimally_rigid, + "has_degeneracy": self.assembly.has_degeneracy, + }, + } + + +# --------------------------------------------------------------------------- +# Per-body DOF from nullspace projection +# --------------------------------------------------------------------------- + + +def _compute_per_body_dof( + j_reduced: np.ndarray, + body_ids: list[int], + ground_body: int | None, + body_index: dict[int, int], +) -> list[BodyDofLabel]: + """Compute translational and rotational DOF per body. + + Uses SVD nullspace projection: for each body, extract its + translational (3 cols) and rotational (3 cols) components + from the nullspace basis and compute ranks. + """ + # Build column index mapping for the reduced Jacobian + # (ground body columns have been removed) + col_map: dict[int, int] = {} + col_idx = 0 + for bid in body_ids: + if bid == ground_body: + continue + col_map[bid] = col_idx + col_idx += 1 + + results: list[BodyDofLabel] = [] + + if j_reduced.size == 0: + # No constraints — every body is fully free + for bid in body_ids: + if bid == ground_body: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + else: + results.append(BodyDofLabel(body_id=bid, translational_dof=3, rotational_dof=3)) + return results + + # Full SVD to get nullspace + _u, s, vh = np.linalg.svd(j_reduced, full_matrices=True) + rank = int(np.sum(s > _SVD_TOL)) + n_cols = j_reduced.shape[1] + + if rank >= n_cols: + # Fully constrained — no nullspace + for bid in body_ids: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + return results + + # Nullspace basis: rows of Vh beyond the rank + nullspace = vh[rank:] # shape: (n_cols - rank, n_cols) + + for bid in body_ids: + if bid == ground_body: + results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)) + continue + + idx = col_map[bid] + trans_cols = nullspace[:, idx * 6 : idx * 6 + 3] + rot_cols = nullspace[:, idx * 6 + 3 : idx * 6 + 6] + + # Rank of each block = DOF in that category + if trans_cols.size > 0: + sv_t = np.linalg.svd(trans_cols, compute_uv=False) + t_dof = int(np.sum(sv_t > _SVD_TOL)) + else: + t_dof = 0 + + if rot_cols.size > 0: + sv_r = np.linalg.svd(rot_cols, compute_uv=False) + r_dof = int(np.sum(sv_r > _SVD_TOL)) + else: + r_dof = 0 + + results.append(BodyDofLabel(body_id=bid, translational_dof=t_dof, rotational_dof=r_dof)) + + return results + + +# --------------------------------------------------------------------------- +# Main labeling function +# --------------------------------------------------------------------------- + + +def label_assembly( + bodies: list[RigidBody], + joints: list[Joint], + ground_body: int | None = None, +) -> AssemblyLabels: + """Produce complete ground truth labels for an assembly. + + Runs both the pebble game and Jacobian verification internally, + then correlates their results into per-constraint, per-joint, + per-body, and assembly-level labels. + + Args: + bodies: Rigid bodies in the assembly. + joints: Joints connecting the bodies. + ground_body: If set, this body is fixed to the world. + + Returns: + AssemblyLabels with full label set and embedded ConstraintAnalysis. + """ + # ---- Pebble Game ---- + pg = PebbleGame3D() + all_edge_results: list[dict[str, Any]] = [] + + if ground_body is not None: + pg.add_body(_GROUND_ID) + + for body in bodies: + pg.add_body(body.body_id) + + if ground_body is not None: + ground_joint = Joint( + joint_id=-1, + body_a=ground_body, + body_b=_GROUND_ID, + joint_type=JointType.FIXED, + anchor_a=bodies[0].position if bodies else np.zeros(3), + anchor_b=bodies[0].position if bodies else np.zeros(3), + ) + pg.add_joint(ground_joint) + + for joint in joints: + results = pg.add_joint(joint) + all_edge_results.extend(results) + + grounded = ground_body is not None + combinatorial_independent = len(pg.state.independent_edges) + raw_dof = pg.get_dof() + ground_offset = 6 if grounded else 0 + effective_dof = raw_dof - ground_offset + effective_internal_dof = max(0, effective_dof - (0 if grounded else 6)) + + redundant_count = pg.get_redundant_count() + if redundant_count > 0 and effective_internal_dof > 0: + classification = "mixed" + elif redundant_count > 0: + classification = "overconstrained" + elif effective_internal_dof > 0: + classification = "underconstrained" + else: + classification = "well-constrained" + + # ---- Jacobian Verification ---- + verifier = JacobianVerifier(bodies) + + for joint in joints: + verifier.add_joint_constraints(joint) + + j_full = verifier.get_jacobian() + j_reduced = j_full.copy() + if ground_body is not None and j_reduced.size > 0: + idx = verifier.body_index[ground_body] + cols_to_remove = list(range(idx * 6, (idx + 1) * 6)) + j_reduced = np.delete(j_reduced, cols_to_remove, axis=1) + + if j_reduced.size > 0: + sv = np.linalg.svd(j_reduced, compute_uv=False) + jacobian_rank = int(np.sum(sv > _SVD_TOL)) + else: + jacobian_rank = 0 + + n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies) + jacobian_nullity = n_cols - jacobian_rank + dependent_rows = verifier.find_dependencies() + dependent_set = set(dependent_rows) + + trivial_dof = 0 if grounded else 6 + jacobian_internal_dof = jacobian_nullity - trivial_dof + geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank) + is_rigid = jacobian_nullity <= trivial_dof + is_minimally_rigid = is_rigid and len(dependent_rows) == 0 + + # ---- Per-constraint labels ---- + # Map Jacobian rows to (joint_id, constraint_index). + # Rows are added contiguously per joint in the same order as joints. + row_to_joint: list[tuple[int, int]] = [] + for joint in joints: + dof = joint.joint_type.dof + for ci in range(dof): + row_to_joint.append((joint.joint_id, ci)) + + per_constraint: list[ConstraintLabel] = [] + for edge_idx, edge_result in enumerate(all_edge_results): + jid = edge_result["joint_id"] + ci = edge_result["constraint_index"] + pebble_indep = edge_result["independent"] + + # Find matching Jacobian row + jacobian_indep = True + if edge_idx < len(row_to_joint): + row_idx = edge_idx + jacobian_indep = row_idx not in dependent_set + + per_constraint.append( + ConstraintLabel( + joint_id=jid, + constraint_idx=ci, + pebble_independent=pebble_indep, + jacobian_independent=jacobian_indep, + ) + ) + + # ---- Per-joint labels ---- + joint_agg: dict[int, JointLabel] = {} + for cl in per_constraint: + if cl.joint_id not in joint_agg: + joint_agg[cl.joint_id] = JointLabel( + joint_id=cl.joint_id, + independent_count=0, + redundant_count=0, + total=0, + ) + jl = joint_agg[cl.joint_id] + jl.total += 1 + if cl.pebble_independent: + jl.independent_count += 1 + else: + jl.redundant_count += 1 + + per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg] + + # ---- Per-body DOF labels ---- + body_ids = [b.body_id for b in bodies] + per_body = _compute_per_body_dof( + j_reduced, + body_ids, + ground_body, + verifier.body_index, + ) + + # ---- Assembly label ---- + assembly_label = AssemblyLabel( + classification=classification, + total_dof=max(0, jacobian_internal_dof), + redundant_count=redundant_count, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + has_degeneracy=geometric_degeneracies > 0, + ) + + # ---- ConstraintAnalysis (for backward compat) ---- + analysis = ConstraintAnalysis( + combinatorial_dof=effective_dof, + combinatorial_internal_dof=effective_internal_dof, + combinatorial_redundant=redundant_count, + combinatorial_classification=classification, + per_edge_results=all_edge_results, + jacobian_rank=jacobian_rank, + jacobian_nullity=jacobian_nullity, + jacobian_internal_dof=max(0, jacobian_internal_dof), + numerically_dependent=dependent_rows, + geometric_degeneracies=geometric_degeneracies, + is_rigid=is_rigid, + is_minimally_rigid=is_minimally_rigid, + ) + + return AssemblyLabels( + per_constraint=per_constraint, + per_joint=per_joint, + per_body=per_body, + assembly=assembly_label, + analysis=analysis, + ) diff --git a/tests/datagen/test_generator.py b/tests/datagen/test_generator.py index dc70baa..3b120fa 100644 --- a/tests/datagen/test_generator.py +++ b/tests/datagen/test_generator.py @@ -513,6 +513,7 @@ class TestTrainingBatch: "body_orientations", "joints", "joint_labels", + "labels", "assembly_classification", "is_rigid", "is_minimally_rigid", @@ -586,6 +587,17 @@ class TestTrainingBatch: assert len(o) == 3 assert len(o[0]) == 3 + def test_labels_structure(self) -> None: + """Each example has labels dict with expected sub-keys.""" + gen = SyntheticAssemblyGenerator(seed=42) + batch = gen.generate_training_batch(10) + for ex in batch: + labels = ex["labels"] + assert "per_constraint" in labels + assert "per_joint" in labels + assert "per_body" in labels + assert "assembly" in labels + def test_grounded_field_present(self) -> None: gen = SyntheticAssemblyGenerator(seed=42) batch = gen.generate_training_batch(10) diff --git a/tests/datagen/test_labeling.py b/tests/datagen/test_labeling.py new file mode 100644 index 0000000..a81e2e3 --- /dev/null +++ b/tests/datagen/test_labeling.py @@ -0,0 +1,346 @@ +"""Tests for solver.datagen.labeling -- ground truth labeling pipeline.""" + +from __future__ import annotations + +import json + +import numpy as np + +from solver.datagen.labeling import ( + label_assembly, +) +from solver.datagen.types import Joint, JointType, RigidBody + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]: + return [RigidBody(body_id=i, position=np.array(pos)) for i, pos in enumerate(positions)] + + +def _make_joint( + jid: int, + a: int, + b: int, + jtype: JointType, + axis: tuple[float, ...] = (0.0, 0.0, 1.0), +) -> Joint: + return Joint( + joint_id=jid, + body_a=a, + body_b=b, + joint_type=jtype, + anchor_a=np.zeros(3), + anchor_b=np.zeros(3), + axis=np.array(axis), + ) + + +# --------------------------------------------------------------------------- +# Per-constraint labels +# --------------------------------------------------------------------------- + + +class TestConstraintLabels: + """Per-constraint labels combine pebble game and Jacobian results.""" + + def test_fixed_joint_all_independent(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 6 + for cl in labels.per_constraint: + assert cl.pebble_independent is True + assert cl.jacobian_independent is True + + def test_revolute_joint_all_independent(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 5 + for cl in labels.per_constraint: + assert cl.pebble_independent is True + assert cl.jacobian_independent is True + + def test_chain_constraint_count(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_constraint) == 10 # 5 + 5 + + def test_constraint_joint_ids(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.BALL), + ] + labels = label_assembly(bodies, joints, ground_body=0) + j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0] + j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1] + assert len(j0_constraints) == 5 # revolute + assert len(j1_constraints) == 3 # ball + + def test_overconstrained_has_pebble_redundant(self) -> None: + """Triangle with revolute joints: some constraints redundant.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent) + assert pebble_redundant > 0 + + +# --------------------------------------------------------------------------- +# Per-joint labels +# --------------------------------------------------------------------------- + + +class TestJointLabels: + """Per-joint aggregated labels.""" + + def test_fixed_joint_counts(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_joint) == 1 + jl = labels.per_joint[0] + assert jl.joint_id == 0 + assert jl.independent_count == 6 + assert jl.redundant_count == 0 + assert jl.total == 6 + + def test_overconstrained_has_redundant_joints(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + total_redundant = sum(jl.redundant_count for jl in labels.per_joint) + assert total_redundant > 0 + + def test_joint_total_equals_dof(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.BALL)] + labels = label_assembly(bodies, joints, ground_body=0) + jl = labels.per_joint[0] + assert jl.total == 3 # ball has 3 DOF + + +# --------------------------------------------------------------------------- +# Per-body DOF labels +# --------------------------------------------------------------------------- + + +class TestBodyDofLabels: + """Per-body DOF signatures from nullspace projection.""" + + def test_fixed_joint_grounded_both_zero(self) -> None: + """Two bodies + fixed joint + grounded: both fully constrained.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + for bl in labels.per_body: + assert bl.translational_dof == 0 + assert bl.rotational_dof == 0 + + def test_revolute_has_rotational_dof(self) -> None: + """Two bodies + revolute + grounded: body 1 has rotational DOF.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + b1 = next(b for b in labels.per_body if b.body_id == 1) + # Revolute allows 1 rotation DOF + assert b1.rotational_dof >= 1 + + def test_dof_bounds(self) -> None: + """All DOF values should be in [0, 3].""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + for bl in labels.per_body: + assert 0 <= bl.translational_dof <= 3 + assert 0 <= bl.rotational_dof <= 3 + + def test_floating_more_dof_than_grounded(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + grounded = label_assembly(bodies, joints, ground_body=0) + floating = label_assembly(bodies, joints, ground_body=None) + g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body) + f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body) + assert f_total > g_total + + def test_grounded_body_zero_dof(self) -> None: + """The grounded body should have 0 DOF.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + b0 = next(b for b in labels.per_body if b.body_id == 0) + assert b0.translational_dof == 0 + assert b0.rotational_dof == 0 + + def test_body_count_matches(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.BALL), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert len(labels.per_body) == 3 + + +# --------------------------------------------------------------------------- +# Assembly label +# --------------------------------------------------------------------------- + + +class TestAssemblyLabel: + """Assembly-wide summary labels.""" + + def test_underconstrained_chain(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.classification == "underconstrained" + assert labels.assembly.is_rigid is False + + def test_well_constrained(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.FIXED)] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.classification == "well-constrained" + assert labels.assembly.is_rigid is True + assert labels.assembly.is_minimally_rigid is True + + def test_overconstrained(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE), + _make_joint(1, 1, 2, JointType.REVOLUTE), + _make_joint(2, 2, 0, JointType.REVOLUTE), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.redundant_count > 0 + + def test_has_degeneracy_with_parallel_axes(self) -> None: + """Parallel revolute axes in a loop create geometric degeneracy.""" + z_axis = (0.0, 0.0, 1.0) + bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0)) + joints = [ + _make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis), + _make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis), + _make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis), + _make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis), + ] + labels = label_assembly(bodies, joints, ground_body=0) + assert labels.assembly.has_degeneracy is True + + +# --------------------------------------------------------------------------- +# Serialization +# --------------------------------------------------------------------------- + + +class TestToDict: + """to_dict produces JSON-serializable output.""" + + def test_top_level_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + assert set(d.keys()) == { + "per_constraint", + "per_joint", + "per_body", + "assembly", + } + + def test_per_constraint_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + for item in d["per_constraint"]: + assert set(item.keys()) == { + "joint_id", + "constraint_idx", + "pebble_independent", + "jacobian_independent", + } + + def test_assembly_keys(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + assert set(d["assembly"].keys()) == { + "classification", + "total_dof", + "redundant_count", + "is_rigid", + "is_minimally_rigid", + "has_degeneracy", + } + + def test_json_serializable(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + d = labels.to_dict() + # Should not raise + serialized = json.dumps(d) + assert isinstance(serialized, str) + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestLabelAssemblyEdgeCases: + """Edge cases for label_assembly.""" + + def test_no_joints(self) -> None: + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + labels = label_assembly(bodies, [], ground_body=0) + assert len(labels.per_constraint) == 0 + assert len(labels.per_joint) == 0 + assert labels.assembly.classification == "underconstrained" + # Non-ground body should be fully free + b1 = next(b for b in labels.per_body if b.body_id == 1) + assert b1.translational_dof == 3 + assert b1.rotational_dof == 3 + + def test_no_joints_floating(self) -> None: + bodies = _make_bodies((0, 0, 0)) + labels = label_assembly(bodies, [], ground_body=None) + assert len(labels.per_body) == 1 + assert labels.per_body[0].translational_dof == 3 + assert labels.per_body[0].rotational_dof == 3 + + def test_analysis_embedded(self) -> None: + """AssemblyLabels.analysis should be a valid ConstraintAnalysis.""" + bodies = _make_bodies((0, 0, 0), (2, 0, 0)) + joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)] + labels = label_assembly(bodies, joints, ground_body=0) + analysis = labels.analysis + assert hasattr(analysis, "combinatorial_classification") + assert hasattr(analysis, "jacobian_rank") + assert hasattr(analysis, "is_rigid") From f29060491e5b99d466911a5f500040becf6850ec Mon Sep 17 00:00:00 2001 From: forbes-0023 Date: Tue, 3 Feb 2026 08:44:31 -0600 Subject: [PATCH 12/12] feat(datagen): add dataset generation CLI with sharding and checkpointing - Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator, ShardSpec/ShardResult dataclasses, parallel shard generation via ProcessPoolExecutor, checkpoint/resume support, index and stats output - Add scripts/generate_synthetic.py CLI entry point with Hydra-first and argparse fallback modes - Add minimal YAML parser (parse_simple_yaml) for config loading without PyYAML dependency - Add progress display with tqdm fallback to print-based ETA - Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every - Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator exports - Add tests/datagen/test_dataset.py with 28 tests covering config, YAML parsing, seed derivation, end-to-end generation, resume, stats/index structure, determinism, and CLI integration Closes #10 --- configs/dataset/synthetic.yaml | 2 + scripts/generate_synthetic.py | 115 ++++++ solver/datagen/__init__.py | 3 + solver/datagen/dataset.py | 624 +++++++++++++++++++++++++++++++++ tests/datagen/test_dataset.py | 337 ++++++++++++++++++ 5 files changed, 1081 insertions(+) create mode 100644 scripts/generate_synthetic.py create mode 100644 solver/datagen/dataset.py create mode 100644 tests/datagen/test_dataset.py diff --git a/configs/dataset/synthetic.yaml b/configs/dataset/synthetic.yaml index 8450ad8..5f8c117 100644 --- a/configs/dataset/synthetic.yaml +++ b/configs/dataset/synthetic.yaml @@ -2,6 +2,7 @@ name: synthetic num_assemblies: 100000 output_dir: data/synthetic +shard_size: 1000 complexity_distribution: simple: 0.4 # 2-5 bodies @@ -22,3 +23,4 @@ templates: grounded_ratio: 0.5 seed: 42 num_workers: 4 +checkpoint_every: 5 diff --git a/scripts/generate_synthetic.py b/scripts/generate_synthetic.py new file mode 100644 index 0000000..9a7f99f --- /dev/null +++ b/scripts/generate_synthetic.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Generate synthetic assembly dataset for kindred-solver training. + +Usage (argparse fallback — always available):: + + python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4 + +Usage (Hydra — when hydra-core is installed):: + + python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4 +""" + +from __future__ import annotations + + +def _try_hydra_main() -> bool: + """Attempt to run via Hydra. Returns *True* if Hydra handled it.""" + try: + import hydra # type: ignore[import-untyped] + from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped] + except ImportError: + return False + + @hydra.main( + config_path="../configs/dataset", + config_name="synthetic", + version_base=None, + ) + def _run(cfg: DictConfig) -> None: # type: ignore[type-arg] + from solver.datagen.dataset import DatasetConfig, DatasetGenerator + + config_dict = OmegaConf.to_container(cfg, resolve=True) + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + _run() # type: ignore[no-untyped-call] + return True + + +def _argparse_main() -> None: + """Fallback CLI using argparse.""" + import argparse + + parser = argparse.ArgumentParser( + description="Generate synthetic assembly dataset", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to YAML config file (optional)", + ) + parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory") + parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard") + parser.add_argument("--body-count-min", type=int, default=None, help="Min body count") + parser.add_argument("--body-count-max", type=int, default=None, help="Max body count") + parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio") + parser.add_argument("--seed", type=int, default=None, help="Random seed") + parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers") + parser.add_argument( + "--checkpoint-every", + type=int, + default=None, + help="Checkpoint interval (shards)", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Do not resume from existing checkpoints", + ) + + args = parser.parse_args() + + from solver.datagen.dataset import ( + DatasetConfig, + DatasetGenerator, + parse_simple_yaml, + ) + + config_dict: dict[str, object] = {} + if args.config: + config_dict = parse_simple_yaml(args.config) # type: ignore[assignment] + + # CLI args override config file (only when explicitly provided) + _override_map = { + "num_assemblies": args.num_assemblies, + "output_dir": args.output_dir, + "shard_size": args.shard_size, + "body_count_min": args.body_count_min, + "body_count_max": args.body_count_max, + "grounded_ratio": args.grounded_ratio, + "seed": args.seed, + "num_workers": args.num_workers, + "checkpoint_every": args.checkpoint_every, + } + for key, val in _override_map.items(): + if val is not None: + config_dict[key] = val + + if args.no_resume: + config_dict["resume"] = False + + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + +def main() -> None: + """Entry point: try Hydra first, fall back to argparse.""" + if not _try_hydra_main(): + _argparse_main() + + +if __name__ == "__main__": + main() diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 169907e..7e26e21 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,6 +1,7 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly +from solver.datagen.dataset import DatasetConfig, DatasetGenerator from solver.datagen.generator import ( COMPLEXITY_RANGES, AxisStrategy, @@ -22,6 +23,8 @@ __all__ = [ "AssemblyLabels", "AxisStrategy", "ConstraintAnalysis", + "DatasetConfig", + "DatasetGenerator", "JacobianVerifier", "Joint", "JointType", diff --git a/solver/datagen/dataset.py b/solver/datagen/dataset.py new file mode 100644 index 0000000..e5a4ce2 --- /dev/null +++ b/solver/datagen/dataset.py @@ -0,0 +1,624 @@ +"""Dataset generation orchestrator with sharding, checkpointing, and statistics. + +Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator` +for parallel generation of synthetic assembly training data. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import math +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "DatasetConfig", + "DatasetGenerator", + "parse_simple_yaml", +] + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class DatasetConfig: + """Configuration for synthetic dataset generation.""" + + name: str = "synthetic" + num_assemblies: int = 100_000 + output_dir: str = "data/synthetic" + shard_size: int = 1000 + complexity_distribution: dict[str, float] = field( + default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2} + ) + body_count_min: int = 2 + body_count_max: int = 50 + templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"]) + grounded_ratio: float = 0.5 + seed: int = 42 + num_workers: int = 4 + checkpoint_every: int = 5 + resume: bool = True + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> DatasetConfig: + """Construct from a parsed config dict (e.g. YAML or OmegaConf). + + Handles both flat keys (``body_count_min``) and nested forms + (``body_count: {min: 2, max: 50}``). + """ + kw: dict[str, Any] = {} + for key in ( + "name", + "num_assemblies", + "output_dir", + "shard_size", + "grounded_ratio", + "seed", + "num_workers", + "checkpoint_every", + "resume", + ): + if key in d: + kw[key] = d[key] + + # Handle nested body_count dict + if "body_count" in d and isinstance(d["body_count"], dict): + bc = d["body_count"] + if "min" in bc: + kw["body_count_min"] = int(bc["min"]) + if "max" in bc: + kw["body_count_max"] = int(bc["max"]) + else: + if "body_count_min" in d: + kw["body_count_min"] = int(d["body_count_min"]) + if "body_count_max" in d: + kw["body_count_max"] = int(d["body_count_max"]) + + if "complexity_distribution" in d: + cd = d["complexity_distribution"] + if isinstance(cd, dict): + kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()} + if "templates" in d and isinstance(d["templates"], list): + kw["templates"] = [str(t) for t in d["templates"]] + + return cls(**kw) + + +# --------------------------------------------------------------------------- +# Shard specification / result +# --------------------------------------------------------------------------- + + +@dataclass +class ShardSpec: + """Specification for generating a single shard.""" + + shard_id: int + start_example_id: int + count: int + seed: int + complexity_distribution: dict[str, float] + body_count_min: int + body_count_max: int + grounded_ratio: float + + +@dataclass +class ShardResult: + """Result returned from a shard worker.""" + + shard_id: int + num_examples: int + file_path: str + generation_time_s: float + + +# --------------------------------------------------------------------------- +# Seed derivation +# --------------------------------------------------------------------------- + + +def _derive_shard_seed(global_seed: int, shard_id: int) -> int: + """Derive a deterministic per-shard seed from the global seed.""" + h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest() + return int(h[:8], 16) + + +# --------------------------------------------------------------------------- +# Progress display +# --------------------------------------------------------------------------- + + +class _PrintProgress: + """Fallback progress display when tqdm is unavailable.""" + + def __init__(self, total: int) -> None: + self.total = total + self.current = 0 + self.start_time = time.monotonic() + + def update(self, n: int = 1) -> None: + self.current += n + elapsed = time.monotonic() - self.start_time + rate = self.current / elapsed if elapsed > 0 else 0.0 + eta = (self.total - self.current) / rate if rate > 0 else 0.0 + pct = 100.0 * self.current / self.total + sys.stdout.write( + f"\r[{pct:5.1f}%] {self.current}/{self.total} shards" + f" | {rate:.1f} shards/s | ETA: {eta:.0f}s" + ) + sys.stdout.flush() + + def close(self) -> None: + sys.stdout.write("\n") + sys.stdout.flush() + + +def _make_progress(total: int) -> _PrintProgress: + """Create a progress tracker (tqdm if available, else print-based).""" + try: + from tqdm import tqdm # type: ignore[import-untyped] + + return tqdm(total=total, desc="Generating shards", unit="shard") # type: ignore[no-any-return,return-value] + except ImportError: + return _PrintProgress(total) + + +# --------------------------------------------------------------------------- +# Shard I/O +# --------------------------------------------------------------------------- + + +def _save_shard( + shard_id: int, + examples: list[dict[str, Any]], + shards_dir: Path, +) -> Path: + """Save a shard to disk (.pt if torch available, else .json).""" + shards_dir.mkdir(parents=True, exist_ok=True) + try: + import torch # type: ignore[import-untyped] + + path = shards_dir / f"shard_{shard_id:05d}.pt" + torch.save(examples, path) + except ImportError: + path = shards_dir / f"shard_{shard_id:05d}.json" + with open(path, "w") as f: + json.dump(examples, f) + return path + + +def _load_shard(path: Path) -> list[dict[str, Any]]: + """Load a shard from disk (.pt or .json).""" + if path.suffix == ".pt": + import torch # type: ignore[import-untyped] + + result: list[dict[str, Any]] = torch.load(path, weights_only=False) + return result + with open(path) as f: + result = json.load(f) + return result + + +def _shard_format() -> str: + """Return the shard file extension based on available libraries.""" + try: + import torch # type: ignore[import-untyped] # noqa: F401 + + return ".pt" + except ImportError: + return ".json" + + +# --------------------------------------------------------------------------- +# Shard worker (module-level for pickling) +# --------------------------------------------------------------------------- + + +def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult: + """Generate a single shard — top-level function for ProcessPoolExecutor.""" + from solver.datagen.generator import SyntheticAssemblyGenerator + + t0 = time.monotonic() + gen = SyntheticAssemblyGenerator(seed=spec.seed) + rng = np.random.default_rng(spec.seed + 1) + + tiers = list(spec.complexity_distribution.keys()) + probs_list = [spec.complexity_distribution[t] for t in tiers] + total = sum(probs_list) + probs = [p / total for p in probs_list] + + examples: list[dict[str, Any]] = [] + for i in range(spec.count): + tier_idx = int(rng.choice(len(tiers), p=probs)) + tier = tiers[tier_idx] + try: + batch = gen.generate_training_batch( + batch_size=1, + complexity_tier=tier, # type: ignore[arg-type] + grounded_ratio=spec.grounded_ratio, + ) + ex = batch[0] + ex["example_id"] = spec.start_example_id + i + ex["complexity_tier"] = tier + examples.append(ex) + except Exception: + logger.warning( + "Shard %d, example %d failed — skipping", + spec.shard_id, + i, + exc_info=True, + ) + + shards_dir = Path(output_dir) / "shards" + path = _save_shard(spec.shard_id, examples, shards_dir) + + elapsed = time.monotonic() - t0 + return ShardResult( + shard_id=spec.shard_id, + num_examples=len(examples), + file_path=str(path), + generation_time_s=elapsed, + ) + + +# --------------------------------------------------------------------------- +# Minimal YAML parser +# --------------------------------------------------------------------------- + + +def _parse_scalar(value: str) -> int | float | bool | str: + """Parse a YAML scalar value.""" + # Strip inline comments (space + #) + if " #" in value: + value = value[: value.index(" #")].strip() + elif " #" in value: + value = value[: value.index(" #")].strip() + v = value.strip() + if v.lower() in ("true", "yes"): + return True + if v.lower() in ("false", "no"): + return False + try: + return int(v) + except ValueError: + pass + try: + return float(v) + except ValueError: + pass + return v.strip("'\"") + + +def parse_simple_yaml(path: str) -> dict[str, Any]: + """Parse a simple YAML file (flat scalars, one-level dicts, lists). + + This is **not** a full YAML parser. It handles the structure of + ``configs/dataset/synthetic.yaml``. + """ + result: dict[str, Any] = {} + current_key: str | None = None + + with open(path) as f: + for raw_line in f: + line = raw_line.rstrip() + + # Skip blank lines and full-line comments + if not line or line.lstrip().startswith("#"): + continue + + indent = len(line) - len(line.lstrip()) + + if indent == 0 and ":" in line: + key, _, value = line.partition(":") + key = key.strip() + value = value.strip() + if value: + result[key] = _parse_scalar(value) + current_key = None + else: + current_key = key + result[key] = {} + continue + + if indent > 0 and line.lstrip().startswith("- "): + item = line.lstrip()[2:].strip() + if current_key is not None: + if isinstance(result.get(current_key), dict) and not result[current_key]: + result[current_key] = [] + if isinstance(result.get(current_key), list): + result[current_key].append(_parse_scalar(item)) + continue + + if indent > 0 and ":" in line and current_key is not None: + k, _, v = line.partition(":") + k = k.strip() + v = v.strip() + if v: + # Strip inline comments + if " #" in v: + v = v[: v.index(" #")].strip() + if not isinstance(result.get(current_key), dict): + result[current_key] = {} + result[current_key][k] = _parse_scalar(v) + continue + + return result + + +# --------------------------------------------------------------------------- +# Dataset generator orchestrator +# --------------------------------------------------------------------------- + + +class DatasetGenerator: + """Orchestrates parallel dataset generation with sharding and checkpointing.""" + + def __init__(self, config: DatasetConfig) -> None: + self.config = config + self.output_path = Path(config.output_dir) + self.shards_dir = self.output_path / "shards" + self.checkpoint_file = self.output_path / ".checkpoint.json" + self.index_file = self.output_path / "index.json" + self.stats_file = self.output_path / "stats.json" + + # -- public API -- + + def run(self) -> None: + """Generate the full dataset.""" + self.output_path.mkdir(parents=True, exist_ok=True) + self.shards_dir.mkdir(parents=True, exist_ok=True) + + shards = self._plan_shards() + total_shards = len(shards) + + # Resume: find already-completed shards + completed: set[int] = set() + if self.config.resume: + completed = self._find_completed_shards() + + pending = [s for s in shards if s.shard_id not in completed] + + if not pending: + logger.info("All %d shards already complete.", total_shards) + else: + logger.info( + "Generating %d shards (%d already complete).", + len(pending), + len(completed), + ) + progress = _make_progress(len(pending)) + workers = max(1, self.config.num_workers) + checkpoint_counter = 0 + + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit( + _generate_shard_worker, + spec, + str(self.output_path), + ): spec.shard_id + for spec in pending + } + for future in as_completed(futures): + shard_id = futures[future] + try: + result = future.result() + completed.add(result.shard_id) + logger.debug( + "Shard %d: %d examples in %.1fs", + result.shard_id, + result.num_examples, + result.generation_time_s, + ) + except Exception: + logger.error("Shard %d failed", shard_id, exc_info=True) + progress.update(1) + checkpoint_counter += 1 + if checkpoint_counter >= self.config.checkpoint_every: + self._update_checkpoint(completed, total_shards) + checkpoint_counter = 0 + + progress.close() + + # Finalize + self._build_index() + stats = self._compute_statistics() + self._write_statistics(stats) + self._print_summary(stats) + + # Remove checkpoint (generation complete) + if self.checkpoint_file.exists(): + self.checkpoint_file.unlink() + + # -- internal helpers -- + + def _plan_shards(self) -> list[ShardSpec]: + """Divide num_assemblies into shards.""" + n = self.config.num_assemblies + size = self.config.shard_size + num_shards = math.ceil(n / size) + shards: list[ShardSpec] = [] + for i in range(num_shards): + start = i * size + count = min(size, n - start) + shards.append( + ShardSpec( + shard_id=i, + start_example_id=start, + count=count, + seed=_derive_shard_seed(self.config.seed, i), + complexity_distribution=dict(self.config.complexity_distribution), + body_count_min=self.config.body_count_min, + body_count_max=self.config.body_count_max, + grounded_ratio=self.config.grounded_ratio, + ) + ) + return shards + + def _find_completed_shards(self) -> set[int]: + """Scan shards directory for existing shard files.""" + completed: set[int] = set() + if not self.shards_dir.exists(): + return completed + + for p in self.shards_dir.iterdir(): + if p.stem.startswith("shard_"): + try: + shard_id = int(p.stem.split("_")[1]) + # Verify file is non-empty + if p.stat().st_size > 0: + completed.add(shard_id) + except (ValueError, IndexError): + pass + return completed + + def _update_checkpoint(self, completed: set[int], total_shards: int) -> None: + """Write checkpoint file.""" + data = { + "completed_shards": sorted(completed), + "total_shards": total_shards, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + with open(self.checkpoint_file, "w") as f: + json.dump(data, f) + + def _build_index(self) -> None: + """Build index.json mapping shard files to assembly ID ranges.""" + shards_info: dict[str, dict[str, int]] = {} + total_assemblies = 0 + + for p in sorted(self.shards_dir.iterdir()): + if not p.stem.startswith("shard_"): + continue + try: + shard_id = int(p.stem.split("_")[1]) + except (ValueError, IndexError): + continue + examples = _load_shard(p) + count = len(examples) + start_id = shard_id * self.config.shard_size + shards_info[p.name] = {"start_id": start_id, "count": count} + total_assemblies += count + + fmt = _shard_format().lstrip(".") + index = { + "format_version": 1, + "total_assemblies": total_assemblies, + "total_shards": len(shards_info), + "shard_format": fmt, + "shards": shards_info, + } + with open(self.index_file, "w") as f: + json.dump(index, f, indent=2) + + def _compute_statistics(self) -> dict[str, Any]: + """Aggregate statistics across all shards.""" + classification_counts: dict[str, int] = {} + body_count_hist: dict[int, int] = {} + joint_type_counts: dict[str, int] = {} + dof_values: list[int] = [] + degeneracy_values: list[int] = [] + rigid_count = 0 + minimally_rigid_count = 0 + total = 0 + + for p in sorted(self.shards_dir.iterdir()): + if not p.stem.startswith("shard_"): + continue + examples = _load_shard(p) + for ex in examples: + total += 1 + cls = str(ex.get("assembly_classification", "unknown")) + classification_counts[cls] = classification_counts.get(cls, 0) + 1 + nb = int(ex.get("n_bodies", 0)) + body_count_hist[nb] = body_count_hist.get(nb, 0) + 1 + for j in ex.get("joints", []): + jt = str(j.get("type", "unknown")) + joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1 + dof_values.append(int(ex.get("internal_dof", 0))) + degeneracy_values.append(int(ex.get("geometric_degeneracies", 0))) + if ex.get("is_rigid"): + rigid_count += 1 + if ex.get("is_minimally_rigid"): + minimally_rigid_count += 1 + + dof_arr = np.array(dof_values) if dof_values else np.zeros(1) + deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1) + + return { + "total_examples": total, + "classification_distribution": dict(sorted(classification_counts.items())), + "body_count_histogram": dict(sorted(body_count_hist.items())), + "joint_type_distribution": dict(sorted(joint_type_counts.items())), + "dof_statistics": { + "mean": float(dof_arr.mean()), + "std": float(dof_arr.std()), + "min": int(dof_arr.min()), + "max": int(dof_arr.max()), + "median": float(np.median(dof_arr)), + }, + "geometric_degeneracy": { + "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)), + "fraction_with_degeneracy": float(np.mean(deg_arr > 0)), + "mean_degeneracies": float(deg_arr.mean()), + }, + "rigidity": { + "rigid_count": rigid_count, + "rigid_fraction": (rigid_count / total if total > 0 else 0.0), + "minimally_rigid_count": minimally_rigid_count, + "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0), + }, + } + + def _write_statistics(self, stats: dict[str, Any]) -> None: + """Write stats.json.""" + with open(self.stats_file, "w") as f: + json.dump(stats, f, indent=2) + + def _print_summary(self, stats: dict[str, Any]) -> None: + """Print a human-readable summary to stdout.""" + print("\n=== Dataset Generation Summary ===") + print(f"Total examples: {stats['total_examples']}") + print(f"Output directory: {self.output_path}") + print() + print("Classification distribution:") + for cls, count in stats["classification_distribution"].items(): + frac = count / max(stats["total_examples"], 1) * 100 + print(f" {cls}: {count} ({frac:.1f}%)") + print() + print("Joint type distribution:") + for jt, count in stats["joint_type_distribution"].items(): + print(f" {jt}: {count}") + print() + dof = stats["dof_statistics"] + print( + f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]" + ) + rig = stats["rigidity"] + print( + f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} " + f"({rig['rigid_fraction'] * 100:.1f}%) rigid, " + f"{rig['minimally_rigid_count']} minimally rigid" + ) + deg = stats["geometric_degeneracy"] + print( + f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies " + f"({deg['fraction_with_degeneracy'] * 100:.1f}%)" + ) diff --git a/tests/datagen/test_dataset.py b/tests/datagen/test_dataset.py new file mode 100644 index 0000000..398184a --- /dev/null +++ b/tests/datagen/test_dataset.py @@ -0,0 +1,337 @@ +"""Tests for solver.datagen.dataset — dataset generation orchestration.""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from typing import TYPE_CHECKING + +from solver.datagen.dataset import ( + DatasetConfig, + DatasetGenerator, + _derive_shard_seed, + _parse_scalar, + parse_simple_yaml, +) + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any + + +# --------------------------------------------------------------------------- +# DatasetConfig +# --------------------------------------------------------------------------- + + +class TestDatasetConfig: + """DatasetConfig construction and defaults.""" + + def test_defaults(self) -> None: + cfg = DatasetConfig() + assert cfg.num_assemblies == 100_000 + assert cfg.seed == 42 + assert cfg.shard_size == 1000 + assert cfg.num_workers == 4 + + def test_from_dict_flat(self) -> None: + d: dict[str, Any] = {"num_assemblies": 500, "seed": 123} + cfg = DatasetConfig.from_dict(d) + assert cfg.num_assemblies == 500 + assert cfg.seed == 123 + + def test_from_dict_nested_body_count(self) -> None: + d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}} + cfg = DatasetConfig.from_dict(d) + assert cfg.body_count_min == 3 + assert cfg.body_count_max == 20 + + def test_from_dict_flat_body_count(self) -> None: + d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30} + cfg = DatasetConfig.from_dict(d) + assert cfg.body_count_min == 5 + assert cfg.body_count_max == 30 + + def test_from_dict_complexity_distribution(self) -> None: + d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}} + cfg = DatasetConfig.from_dict(d) + assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4} + + def test_from_dict_templates(self) -> None: + d: dict[str, Any] = {"templates": ["chain", "tree"]} + cfg = DatasetConfig.from_dict(d) + assert cfg.templates == ["chain", "tree"] + + +# --------------------------------------------------------------------------- +# Minimal YAML parser +# --------------------------------------------------------------------------- + + +class TestParseScalar: + """_parse_scalar handles different value types.""" + + def test_int(self) -> None: + assert _parse_scalar("42") == 42 + + def test_float(self) -> None: + assert _parse_scalar("3.14") == 3.14 + + def test_bool_true(self) -> None: + assert _parse_scalar("true") is True + + def test_bool_false(self) -> None: + assert _parse_scalar("false") is False + + def test_string(self) -> None: + assert _parse_scalar("hello") == "hello" + + def test_inline_comment(self) -> None: + assert _parse_scalar("0.4 # some comment") == 0.4 + + +class TestParseSimpleYaml: + """parse_simple_yaml handles the synthetic.yaml format.""" + + def test_flat_scalars(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["name"] == "test" + assert result["num"] == 42 + assert result["ratio"] == 0.5 + + def test_nested_dict(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("body_count:\n min: 2\n max: 50\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["body_count"] == {"min": 2, "max": 50} + + def test_list(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("templates:\n - chain\n - tree\n - loop\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["templates"] == ["chain", "tree", "loop"] + + def test_inline_comments(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("dist:\n simple: 0.4 # comment\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["dist"]["simple"] == 0.4 + + def test_synthetic_yaml(self) -> None: + """Parse the actual project config.""" + result = parse_simple_yaml("configs/dataset/synthetic.yaml") + assert result["name"] == "synthetic" + assert result["num_assemblies"] == 100000 + assert isinstance(result["complexity_distribution"], dict) + assert isinstance(result["templates"], list) + assert result["shard_size"] == 1000 + + +# --------------------------------------------------------------------------- +# Shard seed derivation +# --------------------------------------------------------------------------- + + +class TestShardSeedDerivation: + """_derive_shard_seed is deterministic and unique per shard.""" + + def test_deterministic(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(42, 0) + assert s1 == s2 + + def test_different_shards(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(42, 1) + assert s1 != s2 + + def test_different_global_seeds(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(99, 0) + assert s1 != s2 + + +# --------------------------------------------------------------------------- +# DatasetGenerator — small end-to-end tests +# --------------------------------------------------------------------------- + + +class TestDatasetGenerator: + """End-to-end tests with small datasets.""" + + def test_small_generation(self, tmp_path: Path) -> None: + """Generate 10 examples in a single shard.""" + cfg = DatasetConfig( + num_assemblies=10, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + shards_dir = tmp_path / "output" / "shards" + assert shards_dir.exists() + shard_files = sorted(shards_dir.glob("shard_*")) + assert len(shard_files) == 1 + + index_file = tmp_path / "output" / "index.json" + assert index_file.exists() + index = json.loads(index_file.read_text()) + assert index["total_assemblies"] == 10 + + stats_file = tmp_path / "output" / "stats.json" + assert stats_file.exists() + + def test_multi_shard(self, tmp_path: Path) -> None: + """Generate 20 examples across 2 shards.""" + cfg = DatasetConfig( + num_assemblies=20, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + shards_dir = tmp_path / "output" / "shards" + shard_files = sorted(shards_dir.glob("shard_*")) + assert len(shard_files) == 2 + + def test_resume_skips_completed(self, tmp_path: Path) -> None: + """Resume skips already-completed shards.""" + cfg = DatasetConfig( + num_assemblies=20, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + # Record shard modification times + shards_dir = tmp_path / "output" / "shards" + mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")} + + # Remove stats (simulate incomplete) and re-run + (tmp_path / "output" / "stats.json").unlink() + + DatasetGenerator(cfg).run() + + # Shards should NOT have been regenerated + for p in shards_dir.glob("shard_*"): + assert p.stat().st_mtime == mtimes[p.name] + + # Stats should be regenerated + assert (tmp_path / "output" / "stats.json").exists() + + def test_checkpoint_removed(self, tmp_path: Path) -> None: + """Checkpoint file is cleaned up after completion.""" + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / "output"), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + checkpoint = tmp_path / "output" / ".checkpoint.json" + assert not checkpoint.exists() + + def test_stats_structure(self, tmp_path: Path) -> None: + """stats.json has expected top-level keys.""" + cfg = DatasetConfig( + num_assemblies=10, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + stats = json.loads((tmp_path / "output" / "stats.json").read_text()) + assert stats["total_examples"] == 10 + assert "classification_distribution" in stats + assert "body_count_histogram" in stats + assert "joint_type_distribution" in stats + assert "dof_statistics" in stats + assert "geometric_degeneracy" in stats + assert "rigidity" in stats + + def test_index_structure(self, tmp_path: Path) -> None: + """index.json has expected format.""" + cfg = DatasetConfig( + num_assemblies=15, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + index = json.loads((tmp_path / "output" / "index.json").read_text()) + assert index["format_version"] == 1 + assert index["total_assemblies"] == 15 + assert index["total_shards"] == 2 + assert "shards" in index + for _name, info in index["shards"].items(): + assert "start_id" in info + assert "count" in info + + def test_deterministic_output(self, tmp_path: Path) -> None: + """Same seed produces same results.""" + for run_dir in ("run1", "run2"): + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / run_dir), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + s1 = json.loads((tmp_path / "run1" / "stats.json").read_text()) + s2 = json.loads((tmp_path / "run2" / "stats.json").read_text()) + assert s1["total_examples"] == s2["total_examples"] + assert s1["classification_distribution"] == s2["classification_distribution"] + + +# --------------------------------------------------------------------------- +# CLI integration test +# --------------------------------------------------------------------------- + + +class TestCLI: + """Run the script via subprocess.""" + + def test_argparse_mode(self, tmp_path: Path) -> None: + result = subprocess.run( + [ + sys.executable, + "scripts/generate_synthetic.py", + "--num-assemblies", + "5", + "--output-dir", + str(tmp_path / "cli_out"), + "--shard-size", + "5", + "--num-workers", + "1", + "--seed", + "42", + ], + capture_output=True, + text=True, + cwd="/home/developer", + timeout=120, + env={**os.environ, "PYTHONPATH": "/home/developer"}, + ) + assert result.returncode == 0, ( + f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) + assert (tmp_path / "cli_out" / "index.json").exists() + assert (tmp_path / "cli_out" / "stats.json").exists()