Merge remote-tracking branch 'public/main'
Some checks failed
CI / lint (push) Successful in 38s
CI / type-check (push) Successful in 1m47s
CI / test (push) Failing after 3m2s

# Conflicts:
#	.gitignore
#	README.md
This commit is contained in:
2026-02-03 10:53:48 -06:00
49 changed files with 6045 additions and 24 deletions

65
.gitea/workflows/ci.yaml Normal file
View File

@@ -0,0 +1,65 @@
name: CI

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          pip install ruff mypy
          pip install -e ".[dev]" || pip install ruff mypy numpy
      - name: Ruff check
        run: ruff check solver/ freecad/ tests/ scripts/
      - name: Ruff format check
        run: ruff format --check solver/ freecad/ tests/ scripts/

  type-check:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          pip install mypy numpy
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          pip install torch-geometric
          pip install -e ".[dev]"
      - name: Mypy
        run: mypy solver/ freecad/

  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v5
        with:
          python-version: "3.11"
      - name: Install dependencies
        run: |
          pip install torch --index-url https://download.pytorch.org/whl/cpu
          pip install torch-geometric
          pip install -e ".[train,dev]"
      - name: Run tests
        run: pytest tests/ freecad/tests/ -v --tb=short

77
.gitignore vendored
View File

@@ -1,44 +1,83 @@
# Prerequisites
# C++ compiled objects
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
# C++ libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
# C++ executables
*.exe
*.out
*.app

# C++ build
build/
cmake-build-debug/
.vs/
x64/
temp/

# OndselSolver test artifacts
*.bak
assembly.asmt
/testapp/draggingBackhoe.log
/testapp/runPreDragBackhoe.asmt

# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
*.egg

# Virtual environments
.venv/
venv/

# IDE
.vscode/
.idea/
*.swp
*.swo
*~

# mypy / ruff / pytest
.mypy_cache/
.ruff_cache/
.pytest_cache/

# Data (large files tracked separately)
data/synthetic/*.pt
data/fusion360/*.json
data/fusion360/*.step
data/processed/*.pt
!data/**/.gitkeep

# Model checkpoints
*.ckpt
*.pth
*.onnx
*.torchscript

# Experiment tracking
wandb/
runs/

# OS
.DS_Store
Thumbs.db

# Environment
.env

23
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,23 @@
repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.3.4
    hooks:
      - id: ruff
        args: [--fix]
      - id: ruff-format
  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: v1.8.0
    hooks:
      - id: mypy
        additional_dependencies:
          - torch>=2.2
          - numpy>=1.26
        args: [--ignore-missing-imports]
  - repo: https://github.com/compilerla/conventional-pre-commit
    rev: v3.1.0
    hooks:
      - id: conventional-pre-commit
        stages: [commit-msg]
        args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert]

61
Dockerfile Normal file
View File

@@ -0,0 +1,61 @@
# GPU training image (CUDA 12.4, Python 3.11)
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base

ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1

# System deps (freecad + GL libs are the FreeCAD headless deps)
RUN apt-get update && apt-get install -y --no-install-recommends \
        python3.11 python3.11-venv python3.11-dev python3-pip \
        git wget curl \
        freecad \
        libgl1-mesa-glx libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1

# Create venv
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"

# Install PyTorch with CUDA
RUN pip install --no-cache-dir \
        torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124

# Install PyG
RUN pip install --no-cache-dir \
        torch-geometric \
        pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
        -f https://data.pyg.org/whl/torch-2.4.0+cu124.html

WORKDIR /workspace

# Install project — deps first (pyproject only) so the layer caches,
# then the full source. The `|| true` tolerates a source-less install.
COPY pyproject.toml .
RUN pip install --no-cache-dir -e ".[train,dev]" || true
COPY . .
RUN pip install --no-cache-dir -e ".[train,dev]"

# -------------------------------------------------------------------
# NOTE(review): `cpu` is an empty alias of `base` — nothing targets it
# (docker-compose uses `base` and `cpu-only`); kept for compatibility.
FROM base AS cpu

# CPU-only variant (for CI and non-GPU environments)
FROM python:3.11-slim AS cpu-only

ENV PYTHONUNBUFFERED=1

RUN apt-get update && apt-get install -y --no-install-recommends \
        git freecad libgl1-mesa-glx libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

WORKDIR /workspace
COPY pyproject.toml .
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache-dir torch-geometric
COPY . .
RUN pip install --no-cache-dir -e ".[train,dev]"

CMD ["pytest", "tests/", "-v"]

48
Makefile Normal file
View File

@@ -0,0 +1,48 @@
.PHONY: train test lint data-gen export format type-check install dev clean help

PYTHON ?= python
PYTEST ?= pytest
RUFF ?= ruff
MYPY ?= mypy

help: ## Show this help
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
		awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'

install: ## Install core dependencies
	pip install -e .

dev: ## Install all dependencies including dev tools
	pip install -e ".[train,dev]"
	pre-commit install
	pre-commit install --hook-type commit-msg

train: ## Run training (pass CONFIG=path/to/config.yaml)
	$(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG))

test: ## Run test suite
	$(PYTEST) tests/ freecad/tests/ -v --tb=short

lint: ## Run ruff linter
	$(RUFF) check solver/ freecad/ tests/ scripts/

format: ## Format code with ruff
	$(RUFF) format solver/ freecad/ tests/ scripts/
	$(RUFF) check --fix solver/ freecad/ tests/ scripts/

type-check: ## Run mypy type checker
	$(MYPY) solver/ freecad/

data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml)
	$(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG))

export: ## Export trained model for deployment
	$(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL))

clean: ## Remove build artifacts and caches
	rm -rf build/ dist/ *.egg-info/
	rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/
	find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
	find . -type f -name "*.pyc" -delete 2>/dev/null || true

check: lint type-check test ## Run all checks (lint, type-check, test)

View File

@@ -1,7 +1,91 @@
# MbDCode
Assembly Constraints and Multibody Dynamics code
# Kindred Solver
Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited)
Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer.
The MbD theory is at
https://github.com/Ondsel-Development/MbDTheory
## Components
### OndselSolver (C++)
Numerical assembly constraint solver using multibody dynamics. Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver.
- Source: `OndselSolver/`
- Entry point: `OndselSolverMain/`
- Tests: `tests/`, `testapp/`
- Build: CMake
**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory)
#### Building
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
### ML Solver Layer (Python)
Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns.
- Core library: `solver/`
- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling)
- Model architectures: `solver/models/` (GIN, GAT, NNConv)
- Training: `solver/training/`
- Inference: `solver/inference/`
- FreeCAD integration: `freecad/`
- Configuration: `configs/` (Hydra)
#### Setup
```bash
pip install -e ".[train,dev]"
pre-commit install
```
#### Usage
```bash
make help # show all targets
make dev # install all deps + pre-commit hooks
make test # run tests
make lint # run ruff linter
make check # lint + type-check + test
make data-gen # generate synthetic data
make train # run training
make export # export model
```
Docker is also supported:
```bash
docker compose up train # GPU training
docker compose up test # run tests
docker compose up data-gen # generate synthetic data
```
## Repository structure
```
kindred-solver/
├── OndselSolver/ # C++ numerical solver library
├── OndselSolverMain/ # C++ solver CLI entry point
├── tests/ # C++ unit tests + Python tests
├── testapp/ # C++ test application
├── solver/ # Python ML solver library
│ ├── datagen/ # Synthetic data generation (pebble game)
│ ├── datasets/ # PyG dataset adapters
│ ├── models/ # GNN architectures
│ ├── training/ # Training loops
│ ├── evaluation/ # Metrics and visualization
│ └── inference/ # Runtime prediction API
├── freecad/ # FreeCAD workbench integration
├── configs/ # Hydra configs (dataset, model, training, export)
├── scripts/ # CLI utilities
├── data/ # Datasets (not committed)
├── export/ # Model packaging
└── docs/ # Documentation
```
## License
OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE))
ML Solver Layer: Apache-2.0

View File

@@ -0,0 +1,12 @@
# Fusion 360 Gallery dataset config
name: fusion360
data_dir: data/fusion360
output_dir: data/processed
splits:
  train: 0.8
  val: 0.1
  test: 0.1
stratify_by: complexity
seed: 42

View File

@@ -0,0 +1,26 @@
# Synthetic dataset generation config
name: synthetic
num_assemblies: 100000
output_dir: data/synthetic
shard_size: 1000
complexity_distribution:
  simple: 0.4   # 2-5 bodies
  medium: 0.4   # 6-15 bodies
  complex: 0.2  # 16-50 bodies
body_count:
  min: 2
  max: 50
templates:
  - chain
  - tree
  - loop
  - star
  - mixed
grounded_ratio: 0.5
seed: 42
num_workers: 4
checkpoint_every: 5

View File

@@ -0,0 +1,25 @@
# Production model export config
model_checkpoint: checkpoints/finetune/best_val_loss.ckpt
output_dir: export/
formats:
  onnx:
    enabled: true
    opset_version: 17
    dynamic_axes: true
  torchscript:
    enabled: true
model_card:
  version: "0.1.0"
  architecture: baseline
  training_data:
    - synthetic_100k
    - fusion360_gallery
size_budget_mb: 50
inference:
  device: cpu
  batch_size: 1
  confidence_threshold: 0.8

View File

@@ -0,0 +1,24 @@
# Baseline GIN model config
name: baseline
architecture: gin
encoder:
  num_layers: 3
  hidden_dim: 128
  dropout: 0.1
# NOTE(review): feature dims assumed top-level (siblings of encoder) —
# confirm against the model loader / gat.yaml usage.
node_features_dim: 22
edge_features_dim: 22
heads:
  edge_classification:
    enabled: true
    hidden_dim: 64
  graph_classification:
    enabled: true
    num_classes: 4  # rigid, under, over, mixed
  joint_type:
    enabled: true
    num_classes: 12
  dof_regression:
    enabled: true

28
configs/model/gat.yaml Normal file
View File

@@ -0,0 +1,28 @@
# Advanced GAT model config
name: gat_solver
architecture: gat
encoder:
  num_layers: 4
  hidden_dim: 256
  num_heads: 8
  dropout: 0.1
  residual: true
# NOTE(review): feature dims assumed top-level (siblings of encoder) —
# confirm against the model loader / baseline.yaml usage.
node_features_dim: 22
edge_features_dim: 22
heads:
  edge_classification:
    enabled: true
    hidden_dim: 128
  graph_classification:
    enabled: true
    num_classes: 4
  joint_type:
    enabled: true
    num_classes: 12
  dof_regression:
    enabled: true
  dof_tracking:
    enabled: true

View File

@@ -0,0 +1,45 @@
# Fine-tuning on real data config
phase: finetune
dataset: fusion360
model: baseline
pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt

optimizer:
  name: adamw
  # Written as 1.0e-5 etc.: bare "1e-5" is parsed as a *string* by
  # YAML 1.1 loaders (PyYAML), which breaks numeric config consumers.
  lr: 1.0e-5
  weight_decay: 1.0e-4

scheduler:
  name: cosine_annealing
  T_max: 50
  eta_min: 1.0e-7

training:
  epochs: 50
  batch_size: 32
  gradient_clip: 1.0
  early_stopping_patience: 10
  amp: true
  freeze_encoder: false  # set true for frozen encoder experiment

loss:
  edge_weight: 1.0
  graph_weight: 0.5
  joint_type_weight: 0.3
  dof_weight: 0.2
  redundant_penalty: 2.0

checkpointing:
  save_best_val_loss: true
  save_best_val_accuracy: true
  save_every_n_epochs: 5
  checkpoint_dir: checkpoints/finetune

logging:
  backend: wandb
  project: kindred-solver
  log_every_n_steps: 20

seed: 42

View File

@@ -0,0 +1,42 @@
# Synthetic pre-training config
phase: pretrain
dataset: synthetic
model: baseline

optimizer:
  name: adamw
  # Written as 1.0e-3 etc.: bare "1e-3" is parsed as a *string* by
  # YAML 1.1 loaders (PyYAML), which breaks numeric config consumers.
  lr: 1.0e-3
  weight_decay: 1.0e-4

scheduler:
  name: cosine_annealing
  T_max: 100
  eta_min: 1.0e-6

training:
  epochs: 100
  batch_size: 64
  gradient_clip: 1.0
  early_stopping_patience: 10
  amp: true

loss:
  edge_weight: 1.0
  graph_weight: 0.5
  joint_type_weight: 0.3
  dof_weight: 0.2
  redundant_penalty: 2.0  # safety loss multiplier

checkpointing:
  save_best_val_loss: true
  save_best_val_accuracy: true
  save_every_n_epochs: 10
  checkpoint_dir: checkpoints/pretrain

logging:
  backend: wandb  # or tensorboard
  project: kindred-solver
  log_every_n_steps: 50

seed: 42

0
data/fusion360/.gitkeep Normal file
View File

0
data/processed/.gitkeep Normal file
View File

0
data/splits/.gitkeep Normal file
View File

0
data/synthetic/.gitkeep Normal file
View File

39
docker-compose.yml Normal file
View File

@@ -0,0 +1,39 @@
services:
  train:
    build:
      context: .
      dockerfile: Dockerfile
      target: base
    volumes:
      - .:/workspace
      - ./data:/workspace/data
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: all
              capabilities: [gpu]
    command: make train
    environment:
      - CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
      - WANDB_API_KEY=${WANDB_API_KEY:-}

  test:
    build:
      context: .
      dockerfile: Dockerfile
      target: cpu-only
    volumes:
      - .:/workspace
    command: make check

  data-gen:
    build:
      context: .
      dockerfile: Dockerfile
      target: base
    volumes:
      - .:/workspace
      - ./data:/workspace/data
    command: make data-gen

0
docs/.gitkeep Normal file
View File

0
export/.gitkeep Normal file
View File

0
freecad/__init__.py Normal file
View File

View File

View File

View File

97
pyproject.toml Normal file
View File

@@ -0,0 +1,97 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "kindred-solver"
version = "0.1.0"
description = "Assembly constraint prediction via GNN for Kindred Create"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11"
authors = [
    { name = "Kindred Systems" },
]
classifiers = [
    "Development Status :: 3 - Alpha",
    "License :: OSI Approved :: Apache Software License",
    "Programming Language :: Python :: 3",
    "Programming Language :: Python :: 3.11",
    "Programming Language :: Python :: 3.12",
    "Topic :: Scientific/Engineering",
]
dependencies = [
    "torch>=2.2",
    "torch-geometric>=2.5",
    "numpy>=1.26",
    "scipy>=1.12",
]

[project.optional-dependencies]
train = [
    "wandb>=0.16",
    "tensorboard>=2.16",
    "hydra-core>=1.3",
    "omegaconf>=2.3",
    "matplotlib>=3.8",
    "networkx>=3.2",
]
freecad = [
    "pyside6>=6.6",
]
dev = [
    "pytest>=8.0",
    "pytest-cov>=4.1",
    "ruff>=0.3",
    "mypy>=1.8",
    "pre-commit>=3.6",
]

[project.urls]
Repository = "https://git.kindred-systems.com/kindred/solver"

[tool.hatch.build.targets.wheel]
packages = ["solver", "freecad"]

[tool.ruff]
target-version = "py311"
line-length = 100

[tool.ruff.lint]
select = [
    "E",   # pycodestyle errors
    "W",   # pycodestyle warnings
    "F",   # pyflakes
    "I",   # isort
    "N",   # pep8-naming
    "UP",  # pyupgrade
    "B",   # flake8-bugbear
    "SIM", # flake8-simplify
    "TCH", # flake8-type-checking
    "RUF", # ruff-specific
]

[tool.ruff.lint.isort]
known-first-party = ["solver", "freecad"]

[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
check_untyped_defs = true

[[tool.mypy.overrides]]
module = [
    "torch.*",
    "torch_geometric.*",
    "scipy.*",
    "wandb.*",
    "hydra.*",
    "omegaconf.*",
]
ignore_missing_imports = true

[tool.pytest.ini_options]
testpaths = ["tests", "freecad/tests"]
addopts = "-v --tb=short"

View File

@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""Generate synthetic assembly dataset for kindred-solver training.
Usage (argparse fallback — always available)::
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
Usage (Hydra — when hydra-core is installed)::
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
"""
from __future__ import annotations
def _try_hydra_main() -> bool:
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
try:
import hydra # type: ignore[import-untyped]
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
except ImportError:
return False
@hydra.main(
config_path="../configs/dataset",
config_name="synthetic",
version_base=None,
)
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
config_dict = OmegaConf.to_container(cfg, resolve=True)
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
DatasetGenerator(config).run()
_run() # type: ignore[no-untyped-call]
return True
def _argparse_main() -> None:
    """Fallback CLI using argparse.

    Reads an optional YAML config, then applies any explicitly-provided
    CLI flags on top of it, and runs the dataset generator.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate synthetic assembly dataset",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to YAML config file (optional)",
    )
    parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
    parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
    parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
    parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=None,
        help="Checkpoint interval (shards)",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not resume from existing checkpoints",
    )
    args = parser.parse_args()

    from solver.datagen.dataset import (
        DatasetConfig,
        DatasetGenerator,
        parse_simple_yaml,
    )

    config_dict: dict[str, object] = {}
    if args.config:
        config_dict = parse_simple_yaml(args.config)  # type: ignore[assignment]

    # CLI args override config file (only when explicitly provided)
    _override_map = {
        "num_assemblies": args.num_assemblies,
        "output_dir": args.output_dir,
        "shard_size": args.shard_size,
        "body_count_min": args.body_count_min,
        "body_count_max": args.body_count_max,
        "grounded_ratio": args.grounded_ratio,
        "seed": args.seed,
        "num_workers": args.num_workers,
        "checkpoint_every": args.checkpoint_every,
    }
    for key, val in _override_map.items():
        if val is not None:
            config_dict[key] = val
    if args.no_resume:
        config_dict["resume"] = False

    config = DatasetConfig.from_dict(config_dict)  # type: ignore[arg-type]
    DatasetGenerator(config).run()
def main() -> None:
    """Entry point: try Hydra first, fall back to argparse."""
    if not _try_hydra_main():
        _argparse_main()


if __name__ == "__main__":
    main()

0
solver/__init__.py Normal file
View File

View File

@@ -0,0 +1,37 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.analysis import analyze_assembly
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
from solver.datagen.generator import (
COMPLEXITY_RANGES,
AxisStrategy,
SyntheticAssemblyGenerator,
)
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.labeling import AssemblyLabels, label_assembly
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
PebbleState,
RigidBody,
)
__all__ = [
"COMPLEXITY_RANGES",
"AssemblyLabels",
"AxisStrategy",
"ConstraintAnalysis",
"DatasetConfig",
"DatasetGenerator",
"JacobianVerifier",
"Joint",
"JointType",
"PebbleGame3D",
"PebbleState",
"RigidBody",
"SyntheticAssemblyGenerator",
"analyze_assembly",
"label_assembly",
]

140
solver/datagen/analysis.py Normal file
View File

@@ -0,0 +1,140 @@
"""Combined pebble game + Jacobian verification analysis.
Provides :func:`analyze_assembly`, the main entry point for full rigidity
analysis of an assembly using both combinatorial and numerical methods.
"""
from __future__ import annotations
import numpy as np
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
__all__ = ["analyze_assembly"]
_GROUND_ID = -1
def analyze_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> ConstraintAnalysis:
    """Full rigidity analysis of an assembly using both methods.

    Runs the combinatorial pebble game and the numerical Jacobian SVD,
    then reconciles the two into a single :class:`ConstraintAnalysis`.

    Args:
        bodies: List of rigid bodies in the assembly.
        joints: List of joints connecting bodies.
        ground_body: If set, this body is fixed (adds 6 implicit constraints).

    Returns:
        ConstraintAnalysis with combinatorial and numerical results.
    """
    # --- Pebble Game ---
    pg = PebbleGame3D()
    all_edge_results = []

    # Add a virtual ground body (id=-1) if grounding is requested.
    # Grounding body X means adding a fixed joint between X and
    # the virtual ground. This properly lets the pebble game account
    # for the 6 removed DOF without breaking invariants.
    if ground_body is not None:
        pg.add_body(_GROUND_ID)
    for body in bodies:
        pg.add_body(body.body_id)
    if ground_body is not None:
        ground_joint = Joint(
            joint_id=-1,
            body_a=ground_body,
            body_b=_GROUND_ID,
            joint_type=JointType.FIXED,
            anchor_a=bodies[0].position if bodies else np.zeros(3),
            anchor_b=bodies[0].position if bodies else np.zeros(3),
        )
        # Don't include ground joint edges in the output labels
        # (they're infrastructure, not user constraints).
        pg.add_joint(ground_joint)

    for joint in joints:
        results = pg.add_joint(joint)
        all_edge_results.extend(results)

    combinatorial_independent = len(pg.state.independent_edges)
    grounded = ground_body is not None

    # The virtual ground body contributes 6 pebbles to the total.
    # Subtract those from the reported DOF for user-facing numbers.
    raw_dof = pg.get_dof()
    ground_offset = 6 if grounded else 0
    effective_dof = raw_dof - ground_offset
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))

    # Classify based on effective (adjusted) DOF, not raw pebble game output,
    # because the virtual ground body skews the raw numbers.
    redundant = pg.get_redundant_count()
    if redundant > 0 and effective_internal_dof > 0:
        combinatorial_classification = "mixed"
    elif redundant > 0:
        combinatorial_classification = "overconstrained"
    elif effective_internal_dof > 0:
        combinatorial_classification = "underconstrained"
    else:
        combinatorial_classification = "well-constrained"

    # --- Jacobian Verification ---
    verifier = JacobianVerifier(bodies)
    for joint in joints:
        verifier.add_joint_constraints(joint)

    # If grounded, remove the ground body's columns (fix its DOF)
    j = verifier.get_jacobian()
    if ground_body is not None and j.size > 0:
        idx = verifier.body_index[ground_body]
        cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
        j = np.delete(j, cols_to_remove, axis=1)

    if j.size > 0:
        sv = np.linalg.svd(j, compute_uv=False)
        # Rank = count of singular values above a fixed tolerance.
        jacobian_rank = int(np.sum(sv > 1e-8))
    else:
        jacobian_rank = 0

    n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies)
    jacobian_nullity = n_cols - jacobian_rank
    dependent = verifier.find_dependencies()

    # Adjust for ground: an ungrounded assembly always has 6 trivial
    # (rigid-body motion) DOF.
    trivial_dof = 0 if ground_body is not None else 6
    jacobian_internal_dof = jacobian_nullity - trivial_dof
    geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)

    # Rigidity: numerically rigid if nullity == trivial DOF
    is_rigid = jacobian_nullity <= trivial_dof
    is_minimally_rigid = is_rigid and len(dependent) == 0

    return ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=pg.get_redundant_count(),
        combinatorial_classification=combinatorial_classification,
        per_edge_results=all_edge_results,
        jacobian_rank=jacobian_rank,
        jacobian_nullity=jacobian_nullity,
        jacobian_internal_dof=max(0, jacobian_internal_dof),
        numerically_dependent=dependent,
        geometric_degeneracies=geometric_degeneracies,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
    )

624
solver/datagen/dataset.py Normal file
View File

@@ -0,0 +1,624 @@
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
for parallel generation of synthetic assembly training data.
"""
from __future__ import annotations
import hashlib
import json
import logging
import math
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"DatasetConfig",
"DatasetGenerator",
"parse_simple_yaml",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class DatasetConfig:
    """Configuration for synthetic dataset generation."""

    name: str = "synthetic"
    num_assemblies: int = 100_000
    output_dir: str = "data/synthetic"
    shard_size: int = 1000
    complexity_distribution: dict[str, float] = field(
        default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
    )
    body_count_min: int = 2
    body_count_max: int = 50
    templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"])
    grounded_ratio: float = 0.5
    seed: int = 42
    num_workers: int = 4
    checkpoint_every: int = 5
    resume: bool = True

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
        """Construct from a parsed config dict (e.g. YAML or OmegaConf).

        Handles both flat keys (``body_count_min``) and nested forms
        (``body_count: {min: 2, max: 50}``). Unknown keys are ignored.
        """
        kw: dict[str, Any] = {}
        for key in (
            "name",
            "num_assemblies",
            "output_dir",
            "shard_size",
            "grounded_ratio",
            "seed",
            "num_workers",
            "checkpoint_every",
            "resume",
        ):
            if key in d:
                kw[key] = d[key]
        # Handle nested body_count dict (takes precedence over flat keys)
        if "body_count" in d and isinstance(d["body_count"], dict):
            bc = d["body_count"]
            if "min" in bc:
                kw["body_count_min"] = int(bc["min"])
            if "max" in bc:
                kw["body_count_max"] = int(bc["max"])
        else:
            if "body_count_min" in d:
                kw["body_count_min"] = int(d["body_count_min"])
            if "body_count_max" in d:
                kw["body_count_max"] = int(d["body_count_max"])
        if "complexity_distribution" in d:
            cd = d["complexity_distribution"]
            if isinstance(cd, dict):
                kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()}
        if "templates" in d and isinstance(d["templates"], list):
            kw["templates"] = [str(t) for t in d["templates"]]
        return cls(**kw)
# ---------------------------------------------------------------------------
# Shard specification / result
# ---------------------------------------------------------------------------
@dataclass
class ShardSpec:
    """Specification for generating a single shard (worker input)."""

    shard_id: int                            # index of this shard
    start_example_id: int                    # global id of first example
    count: int                               # examples to generate
    seed: int                                # per-shard RNG seed
    complexity_distribution: dict[str, float]  # tier -> probability
    body_count_min: int
    body_count_max: int
    grounded_ratio: float
@dataclass
class ShardResult:
    """Result returned from a shard worker."""

    shard_id: int           # which shard was generated
    num_examples: int       # examples actually produced (failures skipped)
    file_path: str          # on-disk path of the saved shard
    generation_time_s: float  # wall-clock generation time
# ---------------------------------------------------------------------------
# Seed derivation
# ---------------------------------------------------------------------------
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
"""Derive a deterministic per-shard seed from the global seed."""
h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest()
return int(h[:8], 16)
# ---------------------------------------------------------------------------
# Progress display
# ---------------------------------------------------------------------------
class _PrintProgress:
"""Fallback progress display when tqdm is unavailable."""
def __init__(self, total: int) -> None:
self.total = total
self.current = 0
self.start_time = time.monotonic()
def update(self, n: int = 1) -> None:
self.current += n
elapsed = time.monotonic() - self.start_time
rate = self.current / elapsed if elapsed > 0 else 0.0
eta = (self.total - self.current) / rate if rate > 0 else 0.0
pct = 100.0 * self.current / self.total
sys.stdout.write(
f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
)
sys.stdout.flush()
def close(self) -> None:
sys.stdout.write("\n")
sys.stdout.flush()
def _make_progress(total: int) -> _PrintProgress:
    """Create a progress tracker (tqdm if available, else print-based)."""
    try:
        from tqdm import tqdm  # type: ignore[import-untyped]

        return tqdm(total=total, desc="Generating shards", unit="shard")  # type: ignore[no-any-return,return-value]
    except ImportError:
        return _PrintProgress(total)
# ---------------------------------------------------------------------------
# Shard I/O
# ---------------------------------------------------------------------------
def _save_shard(
shard_id: int,
examples: list[dict[str, Any]],
shards_dir: Path,
) -> Path:
"""Save a shard to disk (.pt if torch available, else .json)."""
shards_dir.mkdir(parents=True, exist_ok=True)
try:
import torch # type: ignore[import-untyped]
path = shards_dir / f"shard_{shard_id:05d}.pt"
torch.save(examples, path)
except ImportError:
path = shards_dir / f"shard_{shard_id:05d}.json"
with open(path, "w") as f:
json.dump(examples, f)
return path
def _load_shard(path: Path) -> list[dict[str, Any]]:
"""Load a shard from disk (.pt or .json)."""
if path.suffix == ".pt":
import torch # type: ignore[import-untyped]
result: list[dict[str, Any]] = torch.load(path, weights_only=False)
return result
with open(path) as f:
result = json.load(f)
return result
def _shard_format() -> str:
"""Return the shard file extension based on available libraries."""
try:
import torch # type: ignore[import-untyped] # noqa: F401
return ".pt"
except ImportError:
return ".json"
# ---------------------------------------------------------------------------
# Shard worker (module-level for pickling)
# ---------------------------------------------------------------------------
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
    """Generate a single shard — top-level function for ProcessPoolExecutor.

    Failures of individual examples are logged and skipped, so the shard
    may contain fewer than ``spec.count`` examples.
    """
    from solver.datagen.generator import SyntheticAssemblyGenerator

    t0 = time.monotonic()
    gen = SyntheticAssemblyGenerator(seed=spec.seed)
    # Separate RNG (offset seed) for tier sampling so it doesn't perturb
    # the generator's own stream.
    rng = np.random.default_rng(spec.seed + 1)

    # Normalize tier probabilities so they sum to 1.
    tiers = list(spec.complexity_distribution.keys())
    probs_list = [spec.complexity_distribution[t] for t in tiers]
    total = sum(probs_list)
    probs = [p / total for p in probs_list]

    examples: list[dict[str, Any]] = []
    for i in range(spec.count):
        tier_idx = int(rng.choice(len(tiers), p=probs))
        tier = tiers[tier_idx]
        try:
            batch = gen.generate_training_batch(
                batch_size=1,
                complexity_tier=tier,  # type: ignore[arg-type]
                grounded_ratio=spec.grounded_ratio,
            )
            ex = batch[0]
            ex["example_id"] = spec.start_example_id + i
            ex["complexity_tier"] = tier
            examples.append(ex)
        except Exception:
            logger.warning(
                "Shard %d, example %d failed — skipping",
                spec.shard_id,
                i,
                exc_info=True,
            )

    shards_dir = Path(output_dir) / "shards"
    path = _save_shard(spec.shard_id, examples, shards_dir)
    elapsed = time.monotonic() - t0
    return ShardResult(
        shard_id=spec.shard_id,
        num_examples=len(examples),
        file_path=str(path),
        generation_time_s=elapsed,
    )
# ---------------------------------------------------------------------------
# Minimal YAML parser
# ---------------------------------------------------------------------------
def _parse_scalar(value: str) -> int | float | bool | str:
"""Parse a YAML scalar value."""
# Strip inline comments (space + #)
if " #" in value:
value = value[: value.index(" #")].strip()
elif " #" in value:
value = value[: value.index(" #")].strip()
v = value.strip()
if v.lower() in ("true", "yes"):
return True
if v.lower() in ("false", "no"):
return False
try:
return int(v)
except ValueError:
pass
try:
return float(v)
except ValueError:
pass
return v.strip("'\"")
def parse_simple_yaml(path: str) -> dict[str, Any]:
    """Parse a simple YAML file (flat scalars, one-level dicts, lists).

    This is **not** a full YAML parser. It handles the structure of
    ``configs/dataset/synthetic.yaml``: top-level scalars, one level of
    nested mappings, and simple ``- item`` lists under a key.
    """
    parsed: dict[str, Any] = {}
    # Name of the top-level key whose (indented) children we are reading.
    open_section: str | None = None
    with open(path) as handle:
        for raw in handle:
            text = raw.rstrip()
            stripped = text.lstrip()
            # Ignore blank lines and comment-only lines.
            if not text or stripped.startswith("#"):
                continue
            depth = len(text) - len(stripped)
            if depth == 0 and ":" in text:
                name, _, rest = text.partition(":")
                name = name.strip()
                rest = rest.strip()
                if rest:
                    # Scalar on the same line: no children expected.
                    parsed[name] = _parse_scalar(rest)
                    open_section = None
                else:
                    # Value follows on subsequent indented lines; placeholder
                    # dict may be re-typed to a list on the first "- " item.
                    open_section = name
                    parsed[name] = {}
                continue
            if depth > 0 and stripped.startswith("- "):
                entry = stripped[2:].strip()
                if open_section is not None:
                    if isinstance(parsed.get(open_section), dict) and not parsed[open_section]:
                        parsed[open_section] = []
                    if isinstance(parsed.get(open_section), list):
                        parsed[open_section].append(_parse_scalar(entry))
                continue
            if depth > 0 and ":" in text and open_section is not None:
                sub, _, val = text.partition(":")
                sub = sub.strip()
                val = val.strip()
                if val:
                    # Drop trailing inline comments before parsing.
                    if " #" in val:
                        val = val[: val.index(" #")].strip()
                    if not isinstance(parsed.get(open_section), dict):
                        parsed[open_section] = {}
                    parsed[open_section][sub] = _parse_scalar(val)
                continue
    return parsed
# ---------------------------------------------------------------------------
# Dataset generator orchestrator
# ---------------------------------------------------------------------------
class DatasetGenerator:
    """Orchestrates parallel dataset generation with sharding and checkpointing.

    Plans a list of :class:`ShardSpec`, generates pending shards in a
    ``ProcessPoolExecutor`` via ``_generate_shard_worker``, and finalizes
    by writing ``index.json`` and ``stats.json`` into ``config.output_dir``.
    """

    def __init__(self, config: DatasetConfig) -> None:
        self.config = config
        self.output_path = Path(config.output_dir)
        self.shards_dir = self.output_path / "shards"
        # Checkpoint enables resume after interruption; deleted on success.
        self.checkpoint_file = self.output_path / ".checkpoint.json"
        self.index_file = self.output_path / "index.json"
        self.stats_file = self.output_path / "stats.json"

    # -- public API --
    def run(self) -> None:
        """Generate the full dataset.

        Skips shards already on disk when ``config.resume`` is set, then
        builds the index, computes statistics, prints a summary, and
        removes the checkpoint file.
        """
        self.output_path.mkdir(parents=True, exist_ok=True)
        self.shards_dir.mkdir(parents=True, exist_ok=True)
        shards = self._plan_shards()
        total_shards = len(shards)
        # Resume: find already-completed shards
        completed: set[int] = set()
        if self.config.resume:
            completed = self._find_completed_shards()
        pending = [s for s in shards if s.shard_id not in completed]
        if not pending:
            logger.info("All %d shards already complete.", total_shards)
        else:
            logger.info(
                "Generating %d shards (%d already complete).",
                len(pending),
                len(completed),
            )
            progress = _make_progress(len(pending))
            workers = max(1, self.config.num_workers)
            checkpoint_counter = 0
            with ProcessPoolExecutor(max_workers=workers) as pool:
                # Map each future back to its shard id for error reporting.
                futures = {
                    pool.submit(
                        _generate_shard_worker,
                        spec,
                        str(self.output_path),
                    ): spec.shard_id
                    for spec in pending
                }
                for future in as_completed(futures):
                    shard_id = futures[future]
                    try:
                        result = future.result()
                        completed.add(result.shard_id)
                        logger.debug(
                            "Shard %d: %d examples in %.1fs",
                            result.shard_id,
                            result.num_examples,
                            result.generation_time_s,
                        )
                    except Exception:
                        # A failed shard is logged but does not abort the run.
                        logger.error("Shard %d failed", shard_id, exc_info=True)
                    progress.update(1)
                    checkpoint_counter += 1
                    # Periodic checkpoint so an interrupted run can resume.
                    if checkpoint_counter >= self.config.checkpoint_every:
                        self._update_checkpoint(completed, total_shards)
                        checkpoint_counter = 0
            progress.close()
        # Finalize
        self._build_index()
        stats = self._compute_statistics()
        self._write_statistics(stats)
        self._print_summary(stats)
        # Remove checkpoint (generation complete)
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    # -- internal helpers --
    def _plan_shards(self) -> list[ShardSpec]:
        """Divide num_assemblies into shards of at most ``config.shard_size``."""
        n = self.config.num_assemblies
        size = self.config.shard_size
        num_shards = math.ceil(n / size)
        shards: list[ShardSpec] = []
        for i in range(num_shards):
            start = i * size
            # Last shard may hold fewer than `size` examples.
            count = min(size, n - start)
            shards.append(
                ShardSpec(
                    shard_id=i,
                    start_example_id=start,
                    count=count,
                    # Per-shard seed so shards are independent and reproducible.
                    seed=_derive_shard_seed(self.config.seed, i),
                    complexity_distribution=dict(self.config.complexity_distribution),
                    body_count_min=self.config.body_count_min,
                    body_count_max=self.config.body_count_max,
                    grounded_ratio=self.config.grounded_ratio,
                )
            )
        return shards

    def _find_completed_shards(self) -> set[int]:
        """Scan shards directory for existing, non-empty shard files."""
        completed: set[int] = set()
        if not self.shards_dir.exists():
            return completed
        for p in self.shards_dir.iterdir():
            if p.stem.startswith("shard_"):
                try:
                    # File names look like "shard_<id>.<ext>".
                    shard_id = int(p.stem.split("_")[1])
                    # Verify file is non-empty
                    if p.stat().st_size > 0:
                        completed.add(shard_id)
                except (ValueError, IndexError):
                    # Not a recognizable shard file name; ignore.
                    pass
        return completed

    def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
        """Write checkpoint file with the sorted set of completed shard ids."""
        data = {
            "completed_shards": sorted(completed),
            "total_shards": total_shards,
            # UTC timestamp in ISO-8601 "Z" form.
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f)

    def _build_index(self) -> None:
        """Build index.json mapping shard files to assembly ID ranges."""
        shards_info: dict[str, dict[str, int]] = {}
        total_assemblies = 0
        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            try:
                shard_id = int(p.stem.split("_")[1])
            except (ValueError, IndexError):
                continue
            examples = _load_shard(p)
            count = len(examples)
            # start_id is derived from shard position, not file contents.
            start_id = shard_id * self.config.shard_size
            shards_info[p.name] = {"start_id": start_id, "count": count}
            total_assemblies += count
        fmt = _shard_format().lstrip(".")
        index = {
            "format_version": 1,
            "total_assemblies": total_assemblies,
            "total_shards": len(shards_info),
            "shard_format": fmt,
            "shards": shards_info,
        }
        with open(self.index_file, "w") as f:
            json.dump(index, f, indent=2)

    def _compute_statistics(self) -> dict[str, Any]:
        """Aggregate statistics across all shards.

        Re-reads every shard from disk and accumulates classification,
        body-count, joint-type, DOF, degeneracy, and rigidity counters.
        """
        classification_counts: dict[str, int] = {}
        body_count_hist: dict[int, int] = {}
        joint_type_counts: dict[str, int] = {}
        dof_values: list[int] = []
        degeneracy_values: list[int] = []
        rigid_count = 0
        minimally_rigid_count = 0
        total = 0
        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            examples = _load_shard(p)
            for ex in examples:
                total += 1
                cls = str(ex.get("assembly_classification", "unknown"))
                classification_counts[cls] = classification_counts.get(cls, 0) + 1
                nb = int(ex.get("n_bodies", 0))
                body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
                for j in ex.get("joints", []):
                    jt = str(j.get("type", "unknown"))
                    joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
                dof_values.append(int(ex.get("internal_dof", 0)))
                degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
                if ex.get("is_rigid"):
                    rigid_count += 1
                if ex.get("is_minimally_rigid"):
                    minimally_rigid_count += 1
        # zeros(1) fallback keeps mean/std/min/max well-defined on an
        # empty dataset (reported values are then all zero).
        dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
        deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)
        return {
            "total_examples": total,
            "classification_distribution": dict(sorted(classification_counts.items())),
            "body_count_histogram": dict(sorted(body_count_hist.items())),
            "joint_type_distribution": dict(sorted(joint_type_counts.items())),
            "dof_statistics": {
                "mean": float(dof_arr.mean()),
                "std": float(dof_arr.std()),
                "min": int(dof_arr.min()),
                "max": int(dof_arr.max()),
                "median": float(np.median(dof_arr)),
            },
            "geometric_degeneracy": {
                "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
                "fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
                "mean_degeneracies": float(deg_arr.mean()),
            },
            "rigidity": {
                "rigid_count": rigid_count,
                "rigid_fraction": (rigid_count / total if total > 0 else 0.0),
                "minimally_rigid_count": minimally_rigid_count,
                "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
            },
        }

    def _write_statistics(self, stats: dict[str, Any]) -> None:
        """Write stats.json."""
        with open(self.stats_file, "w") as f:
            json.dump(stats, f, indent=2)

    def _print_summary(self, stats: dict[str, Any]) -> None:
        """Print a human-readable summary to stdout."""
        print("\n=== Dataset Generation Summary ===")
        print(f"Total examples: {stats['total_examples']}")
        print(f"Output directory: {self.output_path}")
        print()
        print("Classification distribution:")
        for cls, count in stats["classification_distribution"].items():
            # max(..., 1) avoids ZeroDivisionError on an empty dataset.
            frac = count / max(stats["total_examples"], 1) * 100
            print(f"  {cls}: {count} ({frac:.1f}%)")
        print()
        print("Joint type distribution:")
        for jt, count in stats["joint_type_distribution"].items():
            print(f"  {jt}: {count}")
        print()
        dof = stats["dof_statistics"]
        print(
            f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
        )
        rig = stats["rigidity"]
        print(
            f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
            f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
            f"{rig['minimally_rigid_count']} minimally rigid"
        )
        deg = stats["geometric_degeneracy"]
        print(
            f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
            f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
        )

893
solver/datagen/generator.py Normal file
View File

@@ -0,0 +1,893 @@
"""Synthetic assembly graph generator for training data production.
Generates assembly graphs with known constraint classifications using
the pebble game and Jacobian verification. Each assembly is fully labeled
with per-constraint independence flags and assembly-level classification.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
from scipy.spatial.transform import Rotation
from solver.datagen.analysis import analyze_assembly
from solver.datagen.labeling import label_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
if TYPE_CHECKING:
from typing import Any
__all__ = [
"COMPLEXITY_RANGES",
"AxisStrategy",
"ComplexityTier",
"SyntheticAssemblyGenerator",
]
# ---------------------------------------------------------------------------
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
# ---------------------------------------------------------------------------
# Tier name used to pick a body-count range in generate_training_batch.
ComplexityTier = Literal["simple", "medium", "complex"]

# (min, max_exclusive) body counts per tier; consumed via rng.integers(*range).
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
    "simple": (2, 6),
    "medium": (6, 16),
    "complex": (16, 51),
}
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
# How joint axes are sampled — see SyntheticAssemblyGenerator._sample_axis.
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
class SyntheticAssemblyGenerator:
    """Generates assembly graphs with known minimal constraint sets.

    Uses the pebble game to incrementally build assemblies, tracking
    exactly which constraints are independent at each step. This produces
    labeled training data: (assembly_graph, constraint_set, labels).

    Labels per constraint:
        - independent: bool (does this constraint remove a DOF?)
        - redundant: bool (is this constraint overconstrained?)
        - minimal_set: bool (part of a minimal rigidity basis?)
    """

    def __init__(self, seed: int = 42) -> None:
        # All randomness flows through this one generator: a fixed seed
        # reproduces identical assemblies (RNG call order is significant).
        self.rng = np.random.default_rng(seed)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _random_position(self, scale: float = 5.0) -> np.ndarray:
        """Generate random 3D position within [-scale, scale] cube."""
        return self.rng.uniform(-scale, scale, size=3)

    def _random_axis(self) -> np.ndarray:
        """Generate random normalized 3D axis."""
        axis = self.rng.standard_normal(3)
        axis /= np.linalg.norm(axis)
        return axis

    def _random_orientation(self) -> np.ndarray:
        """Generate a random 3x3 rotation matrix."""
        mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
        return mat

    def _cardinal_axis(self) -> np.ndarray:
        """Pick uniformly from the six signed cardinal directions."""
        axes = np.array(
            [
                [1, 0, 0],
                [-1, 0, 0],
                [0, 1, 0],
                [0, -1, 0],
                [0, 0, 1],
                [0, 0, -1],
            ],
            dtype=float,
        )
        result: np.ndarray = axes[int(self.rng.integers(6))]
        return result

    def _near_parallel_axis(
        self,
        base_axis: np.ndarray,
        perturbation_scale: float = 0.05,
    ) -> np.ndarray:
        """Return *base_axis* with a small random perturbation, re-normalized."""
        perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
        return perturbed / np.linalg.norm(perturbed)

    def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
        """Sample a joint axis using the specified strategy."""
        if strategy == "cardinal":
            return self._cardinal_axis()
        if strategy == "near_parallel":
            # Standalone near-parallel perturbs the +Z axis; shared-base
            # propagation is handled by _resolve_axis instead.
            return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
        return self._random_axis()

    def _resolve_axis(
        self,
        strategy: AxisStrategy,
        parallel_axis_prob: float,
        shared_axis: np.ndarray | None,
    ) -> tuple[np.ndarray, np.ndarray | None]:
        """Return (axis_for_this_joint, shared_axis_to_propagate).

        On the first call where *shared_axis* is ``None`` and parallel
        injection triggers, a base axis is chosen and returned as
        *shared_axis* for subsequent calls.
        """
        if shared_axis is not None:
            # Locked onto a base axis: emit near-parallel variants of it.
            return self._near_parallel_axis(shared_axis), shared_axis
        if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
            base = self._sample_axis(strategy)
            return base.copy(), base
        return self._sample_axis(strategy), None

    def _select_joint_type(
        self,
        joint_types: JointType | list[JointType],
    ) -> JointType:
        """Select a joint type from a single type or list (uniform choice)."""
        if isinstance(joint_types, list):
            idx = int(self.rng.integers(0, len(joint_types)))
            return joint_types[idx]
        return joint_types

    def _create_joint(
        self,
        joint_id: int,
        body_a_id: int,
        body_b_id: int,
        pos_a: np.ndarray,
        pos_b: np.ndarray,
        joint_type: JointType,
        *,
        axis: np.ndarray | None = None,
        orient_a: np.ndarray | None = None,
        orient_b: np.ndarray | None = None,
    ) -> Joint:
        """Create a joint between two bodies.

        When orientations are provided, anchor points are offset from
        each body's center along a random local direction rotated into
        world frame, rather than placed at the midpoint.
        """
        if orient_a is not None and orient_b is not None:
            # Offset magnitude scales with body separation, floored at 0.1
            # so coincident bodies still get a non-degenerate offset.
            dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
            offset_scale = dist * 0.2
            local_a = self.rng.standard_normal(3) * offset_scale
            local_b = self.rng.standard_normal(3) * offset_scale
            anchor_a = pos_a + orient_a @ local_a
            anchor_b = pos_b + orient_b @ local_b
        else:
            # No orientations: both anchors coincide at the midpoint.
            anchor = (pos_a + pos_b) / 2.0
            anchor_a = anchor
            anchor_b = anchor
        return Joint(
            joint_id=joint_id,
            body_a=body_a_id,
            body_b=body_b_id,
            joint_type=joint_type,
            anchor_a=anchor_a,
            anchor_b=anchor_b,
            axis=axis if axis is not None else self._random_axis(),
        )

    # ------------------------------------------------------------------
    # Original generators (chain / rigid / overconstrained)
    # ------------------------------------------------------------------
    def generate_chain_assembly(
        self,
        n_bodies: int,
        joint_type: JointType = JointType.REVOLUTE,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a serial kinematic chain.

        Simple but useful: each body connects to the next with the
        specified joint type. Results in an underconstrained assembly
        (serial chain is never rigid without closing loops).
        """
        bodies = []
        joints = []
        # Bodies laid out along the x-axis, 2 units apart.
        for i in range(n_bodies):
            pos = np.array([i * 2.0, 0.0, 0.0])
            bodies.append(
                RigidBody(
                    body_id=i,
                    position=pos,
                    orientation=self._random_orientation(),
                )
            )
        shared_axis: np.ndarray | None = None
        for i in range(n_bodies - 1):
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    i,
                    i,
                    i + 1,
                    bodies[i].position,
                    bodies[i + 1].position,
                    joint_type,
                    axis=axis,
                    orient_a=bodies[i].orientation,
                    orient_b=bodies[i + 1].orientation,
                )
            )
        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    def generate_rigid_assembly(
        self,
        n_bodies: int,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a minimally rigid assembly by adding joints until rigid.

        Strategy: start with fixed joints on a spanning tree (guarantees
        rigidity), then randomly relax some to weaker joint types while
        maintaining rigidity via the pebble game check.
        """
        bodies = []
        for i in range(n_bodies):
            bodies.append(
                RigidBody(
                    body_id=i,
                    position=self._random_position(),
                    orientation=self._random_orientation(),
                )
            )
        # Build spanning tree with fixed joints (overconstrained)
        joints: list[Joint] = []
        shared_axis: np.ndarray | None = None
        for i in range(1, n_bodies):
            # Parent chosen among already-placed bodies — guarantees a tree.
            parent = int(self.rng.integers(0, i))
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    i - 1,
                    parent,
                    i,
                    bodies[parent].position,
                    bodies[i].position,
                    JointType.FIXED,
                    axis=axis,
                    orient_a=bodies[parent].orientation,
                    orient_b=bodies[i].orientation,
                )
            )
        # Try relaxing joints to weaker types while maintaining rigidity
        weaker_types = [
            JointType.REVOLUTE,
            JointType.CYLINDRICAL,
            JointType.BALL,
        ]
        ground = 0 if grounded else None
        # Visit joints in random order; keep the first weaker type that
        # preserves rigidity, otherwise restore the original type.
        for idx in self.rng.permutation(len(joints)):
            original_type = joints[idx].joint_type
            for candidate in weaker_types:
                joints[idx].joint_type = candidate
                analysis = analyze_assembly(bodies, joints, ground_body=ground)
                if analysis.is_rigid:
                    break  # Keep the weaker type
            else:
                # No weaker type kept rigidity: revert and re-analyze.
                joints[idx].joint_type = original_type
                analysis = analyze_assembly(bodies, joints, ground_body=ground)
        return bodies, joints, analysis

    def generate_overconstrained_assembly(
        self,
        n_bodies: int,
        extra_joints: int = 2,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate an assembly with known redundant constraints.

        Starts with a rigid assembly, then adds extra joints that
        the pebble game will flag as redundant.
        """
        bodies, joints, _ = self.generate_rigid_assembly(
            n_bodies,
            grounded=grounded,
            axis_strategy=axis_strategy,
            parallel_axis_prob=parallel_axis_prob,
        )
        joint_id = len(joints)
        shared_axis: np.ndarray | None = None
        for _ in range(extra_joints):
            # Extra joint between two distinct random bodies.
            a, b = self.rng.choice(n_bodies, size=2, replace=False)
            _overcon_types = [
                JointType.REVOLUTE,
                JointType.FIXED,
                JointType.BALL,
            ]
            jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    int(a),
                    int(b),
                    bodies[int(a)].position,
                    bodies[int(b)].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[int(a)].orientation,
                    orient_b=bodies[int(b)].orientation,
                )
            )
            joint_id += 1
        ground = 0 if grounded else None
        analysis = analyze_assembly(bodies, joints, ground_body=ground)
        return bodies, joints, analysis

    # ------------------------------------------------------------------
    # New topology generators
    # ------------------------------------------------------------------
    def generate_tree_assembly(
        self,
        n_bodies: int,
        joint_types: JointType | list[JointType] = JointType.REVOLUTE,
        branching_factor: int = 3,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a random tree topology with configurable branching.

        Creates a tree where each body can have up to *branching_factor*
        children. Different branches can use different joint types if a
        list is provided. Always underconstrained (no closed loops).

        Args:
            n_bodies: Total bodies (root + children).
            joint_types: Single type or list to sample from per joint.
            branching_factor: Max children per parent (1-5 recommended).
        """
        bodies: list[RigidBody] = [
            RigidBody(
                body_id=0,
                position=np.zeros(3),
                orientation=self._random_orientation(),
            )
        ]
        joints: list[Joint] = []
        available_parents = [0]
        next_id = 1
        joint_id = 0
        shared_axis: np.ndarray | None = None
        while next_id < n_bodies and available_parents:
            # Pick a random parent still accepting children.
            pidx = int(self.rng.integers(0, len(available_parents)))
            parent_id = available_parents[pidx]
            parent_pos = bodies[parent_id].position
            max_children = min(branching_factor, n_bodies - next_id)
            n_children = int(self.rng.integers(1, max_children + 1))
            for _ in range(n_children):
                # Child placed 1.5-3.0 units away in a random direction.
                direction = self._random_axis()
                distance = self.rng.uniform(1.5, 3.0)
                child_pos = parent_pos + direction * distance
                child_orient = self._random_orientation()
                bodies.append(
                    RigidBody(
                        body_id=next_id,
                        position=child_pos,
                        orientation=child_orient,
                    )
                )
                jtype = self._select_joint_type(joint_types)
                axis, shared_axis = self._resolve_axis(
                    axis_strategy,
                    parallel_axis_prob,
                    shared_axis,
                )
                joints.append(
                    self._create_joint(
                        joint_id,
                        parent_id,
                        next_id,
                        parent_pos,
                        child_pos,
                        jtype,
                        axis=axis,
                        orient_a=bodies[parent_id].orientation,
                        orient_b=child_orient,
                    )
                )
                available_parents.append(next_id)
                next_id += 1
                joint_id += 1
                if next_id >= n_bodies:
                    break
            # Retire parent if it reached branching limit or randomly
            if n_children >= branching_factor or self.rng.random() < 0.3:
                available_parents.pop(pidx)
        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    def generate_loop_assembly(
        self,
        n_bodies: int,
        joint_types: JointType | list[JointType] = JointType.REVOLUTE,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a single closed loop (ring) of bodies.

        The closing constraint introduces redundancy, making this
        useful for generating overconstrained training data.

        Args:
            n_bodies: Bodies in the ring (>= 3).
            joint_types: Single type or list to sample from per joint.

        Raises:
            ValueError: If *n_bodies* < 3.
        """
        if n_bodies < 3:
            msg = "Loop assembly requires at least 3 bodies"
            raise ValueError(msg)
        bodies: list[RigidBody] = []
        joints: list[Joint] = []
        # Ring radius grows with body count so neighbors stay spaced out.
        base_radius = max(2.0, n_bodies * 0.4)
        for i in range(n_bodies):
            angle = 2 * np.pi * i / n_bodies
            # Jitter radius and height so the ring is not perfectly planar.
            radius = base_radius + self.rng.uniform(-0.5, 0.5)
            x = radius * np.cos(angle)
            y = radius * np.sin(angle)
            z = float(self.rng.uniform(-0.2, 0.2))
            bodies.append(
                RigidBody(
                    body_id=i,
                    position=np.array([x, y, z]),
                    orientation=self._random_orientation(),
                )
            )
        shared_axis: np.ndarray | None = None
        for i in range(n_bodies):
            # Modulo wrap closes the loop (last body joins back to the first).
            next_i = (i + 1) % n_bodies
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    i,
                    i,
                    next_i,
                    bodies[i].position,
                    bodies[next_i].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[i].orientation,
                    orient_b=bodies[next_i].orientation,
                )
            )
        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    def generate_star_assembly(
        self,
        n_bodies: int,
        joint_types: JointType | list[JointType] = JointType.REVOLUTE,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a star topology with central hub and satellites.

        Body 0 is the hub; all other bodies connect directly to it.
        Underconstrained because there are no inter-satellite connections.

        Args:
            n_bodies: Total bodies including hub (>= 2).
            joint_types: Single type or list to sample from per joint.

        Raises:
            ValueError: If *n_bodies* < 2.
        """
        if n_bodies < 2:
            msg = "Star assembly requires at least 2 bodies"
            raise ValueError(msg)
        hub_orient = self._random_orientation()
        bodies: list[RigidBody] = [
            RigidBody(
                body_id=0,
                position=np.zeros(3),
                orientation=hub_orient,
            )
        ]
        joints: list[Joint] = []
        shared_axis: np.ndarray | None = None
        for i in range(1, n_bodies):
            # Satellite placed 2-5 units from the hub in a random direction.
            direction = self._random_axis()
            distance = self.rng.uniform(2.0, 5.0)
            pos = direction * distance
            sat_orient = self._random_orientation()
            bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    i - 1,
                    0,
                    i,
                    np.zeros(3),
                    pos,
                    jtype,
                    axis=axis,
                    orient_a=hub_orient,
                    orient_b=sat_orient,
                )
            )
        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    def generate_mixed_assembly(
        self,
        n_bodies: int,
        joint_types: JointType | list[JointType] = JointType.REVOLUTE,
        edge_density: float = 0.3,
        *,
        grounded: bool = True,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
    ) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
        """Generate a mixed topology combining tree and loop elements.

        Builds a spanning tree for connectivity, then adds extra edges
        based on *edge_density* to create loops and redundancy.

        Args:
            n_bodies: Number of bodies.
            joint_types: Single type or list to sample from per joint.
            edge_density: Fraction of non-tree edges to add (0.0-1.0).

        Raises:
            ValueError: If *edge_density* not in [0.0, 1.0].
        """
        if not 0.0 <= edge_density <= 1.0:
            msg = "edge_density must be in [0.0, 1.0]"
            raise ValueError(msg)
        bodies: list[RigidBody] = []
        joints: list[Joint] = []
        for i in range(n_bodies):
            bodies.append(
                RigidBody(
                    body_id=i,
                    position=self._random_position(),
                    orientation=self._random_orientation(),
                )
            )
        # Phase 1: spanning tree
        joint_id = 0
        # Track undirected edges so phase 2 never duplicates one.
        existing_edges: set[frozenset[int]] = set()
        shared_axis: np.ndarray | None = None
        for i in range(1, n_bodies):
            parent = int(self.rng.integers(0, i))
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    parent,
                    i,
                    bodies[parent].position,
                    bodies[i].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[parent].orientation,
                    orient_b=bodies[i].orientation,
                )
            )
            existing_edges.add(frozenset([parent, i]))
            joint_id += 1
        # Phase 2: extra edges based on density
        candidates: list[tuple[int, int]] = []
        for i in range(n_bodies):
            for j in range(i + 1, n_bodies):
                if frozenset([i, j]) not in existing_edges:
                    candidates.append((i, j))
        n_extra = int(edge_density * len(candidates))
        self.rng.shuffle(candidates)
        for a, b in candidates[:n_extra]:
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    a,
                    b,
                    bodies[a].position,
                    bodies[b].position,
                    jtype,
                    axis=axis,
                    orient_a=bodies[a].orientation,
                    orient_b=bodies[b].orientation,
                )
            )
            joint_id += 1
        analysis = analyze_assembly(
            bodies,
            joints,
            ground_body=0 if grounded else None,
        )
        return bodies, joints, analysis

    # ------------------------------------------------------------------
    # Batch generation
    # ------------------------------------------------------------------
    def generate_training_batch(
        self,
        batch_size: int = 100,
        n_bodies_range: tuple[int, int] | None = None,
        complexity_tier: ComplexityTier | None = None,
        *,
        axis_strategy: AxisStrategy = "random",
        parallel_axis_prob: float = 0.0,
        grounded_ratio: float = 1.0,
    ) -> list[dict[str, Any]]:
        """Generate a batch of labeled training examples.

        Each example contains body positions, joint descriptions,
        per-joint independence labels, and assembly-level classification.

        Args:
            batch_size: Number of assemblies to generate.
            n_bodies_range: ``(min, max_exclusive)`` body count.
                Overridden by *complexity_tier* when both are given.
            complexity_tier: Predefined range (``"simple"`` / ``"medium"``
                / ``"complex"``). Overrides *n_bodies_range*.
            axis_strategy: Axis sampling strategy for joint axes.
            parallel_axis_prob: Probability of parallel axis injection.
            grounded_ratio: Fraction of examples that are grounded.
        """
        if complexity_tier is not None:
            n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
        elif n_bodies_range is None:
            n_bodies_range = (3, 8)
        _joint_pool = [
            JointType.REVOLUTE,
            JointType.BALL,
            JointType.CYLINDRICAL,
            JointType.FIXED,
        ]
        # Geometry kwargs forwarded to every generator call below.
        geo_kw: dict[str, Any] = {
            "axis_strategy": axis_strategy,
            "parallel_axis_prob": parallel_axis_prob,
        }
        examples: list[dict[str, Any]] = []
        for i in range(batch_size):
            n = int(self.rng.integers(*n_bodies_range))
            # Uniformly pick one of the seven generator topologies.
            gen_idx = int(self.rng.integers(7))
            grounded = bool(self.rng.random() < grounded_ratio)
            if gen_idx == 0:
                _chain_types = [
                    JointType.REVOLUTE,
                    JointType.BALL,
                    JointType.CYLINDRICAL,
                ]
                jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
                bodies, joints, analysis = self.generate_chain_assembly(
                    n,
                    jtype,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "chain"
            elif gen_idx == 1:
                bodies, joints, analysis = self.generate_rigid_assembly(
                    n,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "rigid"
            elif gen_idx == 2:
                extra = int(self.rng.integers(1, 4))
                bodies, joints, analysis = self.generate_overconstrained_assembly(
                    n,
                    extra,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "overconstrained"
            elif gen_idx == 3:
                branching = int(self.rng.integers(2, 5))
                bodies, joints, analysis = self.generate_tree_assembly(
                    n,
                    _joint_pool,
                    branching,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "tree"
            elif gen_idx == 4:
                # Loop generator requires >= 3 bodies.
                n = max(n, 3)
                bodies, joints, analysis = self.generate_loop_assembly(
                    n,
                    _joint_pool,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "loop"
            elif gen_idx == 5:
                # Star generator requires >= 2 bodies.
                n = max(n, 2)
                bodies, joints, analysis = self.generate_star_assembly(
                    n,
                    _joint_pool,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "star"
            else:
                density = float(self.rng.uniform(0.2, 0.5))
                bodies, joints, analysis = self.generate_mixed_assembly(
                    n,
                    _joint_pool,
                    density,
                    grounded=grounded,
                    **geo_kw,
                )
                gen_name = "mixed"
            # Produce ground truth labels (includes ConstraintAnalysis)
            ground = 0 if grounded else None
            labels = label_assembly(bodies, joints, ground_body=ground)
            # The labeling analysis supersedes the generator's own analysis.
            analysis = labels.analysis
            # Build per-joint labels from edge results
            joint_labels: dict[int, dict[str, int]] = {}
            for result in analysis.per_edge_results:
                jid = result["joint_id"]
                if jid not in joint_labels:
                    joint_labels[jid] = {
                        "independent_constraints": 0,
                        "redundant_constraints": 0,
                        "total_constraints": 0,
                    }
                joint_labels[jid]["total_constraints"] += 1
                if result["independent"]:
                    joint_labels[jid]["independent_constraints"] += 1
                else:
                    joint_labels[jid]["redundant_constraints"] += 1
            examples.append(
                {
                    "example_id": i,
                    "generator_type": gen_name,
                    "grounded": grounded,
                    "n_bodies": len(bodies),
                    "n_joints": len(joints),
                    "body_positions": [b.position.tolist() for b in bodies],
                    "body_orientations": [b.orientation.tolist() for b in bodies],
                    "joints": [
                        {
                            "joint_id": j.joint_id,
                            "body_a": j.body_a,
                            "body_b": j.body_b,
                            "type": j.joint_type.name,
                            "axis": j.axis.tolist(),
                        }
                        for j in joints
                    ],
                    "joint_labels": joint_labels,
                    "labels": labels.to_dict(),
                    "assembly_classification": (analysis.combinatorial_classification),
                    "is_rigid": analysis.is_rigid,
                    "is_minimally_rigid": analysis.is_minimally_rigid,
                    "internal_dof": analysis.jacobian_internal_dof,
                    "geometric_degeneracies": (analysis.geometric_degeneracies),
                }
            )
        return examples

517
solver/datagen/jacobian.py Normal file
View File

@@ -0,0 +1,517 @@
"""Numerical Jacobian rank verification for assembly constraint analysis.
Builds the constraint Jacobian matrix and analyzes its numerical rank
to detect geometric degeneracies that the combinatorial pebble game
cannot identify (e.g., parallel revolute axes creating hidden dependencies).
References:
- Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D"
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.types import Joint, JointType, RigidBody
if TYPE_CHECKING:
from typing import Any
__all__ = ["JacobianVerifier"]
class JacobianVerifier:
    """Builds and analyzes the constraint Jacobian for numerical rank check.
    The pebble game gives a combinatorial *necessary* condition for
    rigidity. However, geometric special cases (e.g., all revolute axes
    parallel, creating a hidden dependency) require numerical verification.
    For each joint, we construct the constraint Jacobian rows that map
    the 6n-dimensional generalized velocity vector to the constraint
    violation rates. The rank of this Jacobian equals the number of
    truly independent constraints.
    The generalized velocity vector for n bodies is::
        v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z]
    Each scalar constraint C_i contributes one row to J such that::
        dC_i/dt = J_i @ v = 0
    """
    def __init__(self, bodies: list[RigidBody]) -> None:
        # Lookup tables: body_id -> body (for anchor/position math) and
        # body_id -> dense index. Body i owns columns [6i, 6i + 6) of J
        # (3 linear velocity components followed by 3 angular).
        self.bodies = {b.body_id: b for b in bodies}
        self.body_index = {b.body_id: i for i, b in enumerate(bodies)}
        self.n_bodies = len(bodies)
        # One row per scalar constraint, appended by the _build_* methods.
        self.jacobian_rows: list[np.ndarray] = []
        # Parallel metadata for each row (joint_id, constraint type, axis).
        self.row_labels: list[dict[str, Any]] = []
    def _body_cols(self, body_id: int) -> tuple[int, int]:
        """Return the column range [start, end) for a body in J."""
        idx = self.body_index[body_id]
        return idx * 6, (idx + 1) * 6
    def add_joint_constraints(self, joint: Joint) -> int:
        """Add Jacobian rows for all scalar constraints of a joint.
        Returns the number of rows added.
        """
        # Dispatch table keyed by joint type; an unsupported type raises
        # KeyError at the lookup below, which is intentional.
        builder = {
            JointType.FIXED: self._build_fixed,
            JointType.REVOLUTE: self._build_revolute,
            JointType.CYLINDRICAL: self._build_cylindrical,
            JointType.SLIDER: self._build_slider,
            JointType.BALL: self._build_ball,
            JointType.PLANAR: self._build_planar,
            JointType.DISTANCE: self._build_distance,
            JointType.PARALLEL: self._build_parallel,
            JointType.PERPENDICULAR: self._build_perpendicular,
            JointType.UNIVERSAL: self._build_universal,
            JointType.SCREW: self._build_screw,
        }
        rows_before = len(self.jacobian_rows)
        builder[joint.joint_type](joint)
        # Builders append directly to self.jacobian_rows; report the delta.
        return len(self.jacobian_rows) - rows_before
    def _make_row(self) -> np.ndarray:
        """Create a zero row of width 6*n_bodies."""
        return np.zeros(6 * self.n_bodies)
    def _skew(self, v: np.ndarray) -> np.ndarray:
        """Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``."""
        # NOTE: kept as a utility; the builders below use np.cross on
        # individual rows instead of this matrix form.
        return np.array(
            [
                [0, -v[2], v[1]],
                [v[2], 0, -v[0]],
                [-v[1], v[0], 0],
            ]
        )
    # --- Ball-and-socket (spherical) joint: 3 translation constraints ---
    def _build_ball(self, joint: Joint) -> None:
        """Ball joint: coincident point constraint.
        ``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0``
        (3 equations)
        Jacobian rows (for each of x, y, z):
            body_a linear: -I
            body_a angular: +skew(R_a @ r_a)
            body_b linear: +I
            body_b angular: -skew(R_b @ r_b)
        """
        # Use anchor positions directly as world-frame offsets
        r_a = joint.anchor_a - self.bodies[joint.body_a].position
        r_b = joint.anchor_b - self.bodies[joint.body_b].position
        col_a_start, col_a_end = self._body_cols(joint.body_a)
        col_b_start, col_b_end = self._body_cols(joint.body_b)
        for axis_idx in range(3):
            row = self._make_row()
            # e selects the x/y/z component of the vector constraint.
            e = np.zeros(3)
            e[axis_idx] = 1.0
            row[col_a_start : col_a_start + 3] = -e
            # cross(r, e) encodes the angular contribution e . (w x r)
            # via the scalar triple product.
            row[col_a_start + 3 : col_a_end] = np.cross(r_a, e)
            row[col_b_start : col_b_start + 3] = e
            row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e)
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "ball_translation",
                    "axis": axis_idx,
                }
            )
    # --- Fixed joint: 3 translation + 3 rotation constraints ---
    def _build_fixed(self, joint: Joint) -> None:
        """Fixed joint = ball joint + locked rotation.
        Translation part: same as ball joint (3 rows).
        Rotation part: relative angular velocity must be zero (3 rows).
        """
        self._build_ball(joint)
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        for axis_idx in range(3):
            # w_b - w_a = 0, one scalar row per world axis.
            row = self._make_row()
            row[col_a_start + 3 + axis_idx] = -1.0
            row[col_b_start + 3 + axis_idx] = 1.0
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "fixed_rotation",
                    "axis": axis_idx,
                }
            )
    # --- Revolute (hinge) joint: 3 translation + 2 rotation constraints ---
    def _build_revolute(self, joint: Joint) -> None:
        """Revolute joint: rotation only about one axis.
        Translation: same as ball (3 rows).
        Rotation: relative angular velocity must be parallel to hinge axis.
        """
        self._build_ball(joint)
        axis = joint.axis / np.linalg.norm(joint.axis)
        # Constrain the two components of relative angular velocity that
        # are perpendicular to the hinge axis.
        t1, t2 = self._perpendicular_pair(axis)
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        for i, t in enumerate((t1, t2)):
            row = self._make_row()
            row[col_a_start + 3 : col_a_start + 6] = -t
            row[col_b_start + 3 : col_b_start + 6] = t
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "revolute_rotation",
                    "perp_axis": i,
                }
            )
    # --- Cylindrical joint: 2 translation + 2 rotation constraints ---
    def _build_cylindrical(self, joint: Joint) -> None:
        """Cylindrical joint: allows rotation + translation along one axis.
        Translation: constrain motion perpendicular to axis (2 rows).
        Rotation: constrain rotation perpendicular to axis (2 rows).
        """
        axis = joint.axis / np.linalg.norm(joint.axis)
        t1, t2 = self._perpendicular_pair(axis)
        r_a = joint.anchor_a - self.bodies[joint.body_a].position
        r_b = joint.anchor_b - self.bodies[joint.body_b].position
        col_a_start, col_a_end = self._body_cols(joint.body_a)
        col_b_start, col_b_end = self._body_cols(joint.body_b)
        # Translation rows: anchor-point velocity difference projected on
        # the two perpendicular directions must vanish.
        for i, t in enumerate((t1, t2)):
            row = self._make_row()
            row[col_a_start : col_a_start + 3] = -t
            row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
            row[col_b_start : col_b_start + 3] = t
            row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "cylindrical_translation",
                    "perp_axis": i,
                }
            )
        # Rotation rows: relative angular velocity perpendicular to the
        # axis must vanish (rotation about the axis stays free).
        for i, t in enumerate((t1, t2)):
            row = self._make_row()
            row[col_a_start + 3 : col_a_start + 6] = -t
            row[col_b_start + 3 : col_b_start + 6] = t
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "cylindrical_rotation",
                    "perp_axis": i,
                }
            )
    # --- Slider (prismatic) joint: 2 translation + 3 rotation constraints ---
    def _build_slider(self, joint: Joint) -> None:
        """Slider/prismatic joint: translation along one axis only.
        Translation: perpendicular translation constrained (2 rows).
        Rotation: all relative rotation constrained (3 rows).
        """
        axis = joint.axis / np.linalg.norm(joint.axis)
        t1, t2 = self._perpendicular_pair(axis)
        r_a = joint.anchor_a - self.bodies[joint.body_a].position
        r_b = joint.anchor_b - self.bodies[joint.body_b].position
        col_a_start, col_a_end = self._body_cols(joint.body_a)
        col_b_start, col_b_end = self._body_cols(joint.body_b)
        # Translation rows: same form as the cylindrical joint.
        for i, t in enumerate((t1, t2)):
            row = self._make_row()
            row[col_a_start : col_a_start + 3] = -t
            row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
            row[col_b_start : col_b_start + 3] = t
            row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "slider_translation",
                    "perp_axis": i,
                }
            )
        # Rotation rows: lock all three relative angular velocity
        # components (same form as the fixed joint's rotation part).
        for axis_idx in range(3):
            row = self._make_row()
            row[col_a_start + 3 + axis_idx] = -1.0
            row[col_b_start + 3 + axis_idx] = 1.0
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "slider_rotation",
                    "axis": axis_idx,
                }
            )
    # --- Planar joint: 1 translation + 2 rotation constraints ---
    def _build_planar(self, joint: Joint) -> None:
        """Planar joint: constrains to a plane.
        Translation: motion along plane normal constrained (1 row).
        Rotation: rotation about axes in the plane constrained (2 rows).
        """
        # joint.axis is interpreted as the plane normal here.
        normal = joint.axis / np.linalg.norm(joint.axis)
        t1, t2 = self._perpendicular_pair(normal)
        r_a = joint.anchor_a - self.bodies[joint.body_a].position
        r_b = joint.anchor_b - self.bodies[joint.body_b].position
        col_a_start, col_a_end = self._body_cols(joint.body_a)
        col_b_start, col_b_end = self._body_cols(joint.body_b)
        # Single translation row along the plane normal.
        row = self._make_row()
        row[col_a_start : col_a_start + 3] = -normal
        row[col_a_start + 3 : col_a_end] = np.cross(r_a, normal)
        row[col_b_start : col_b_start + 3] = normal
        row[col_b_start + 3 : col_b_end] = -np.cross(r_b, normal)
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "planar_translation",
            }
        )
        # Two rotation rows about the in-plane directions; rotation about
        # the normal stays free.
        for i, t in enumerate((t1, t2)):
            row = self._make_row()
            row[col_a_start + 3 : col_a_start + 6] = -t
            row[col_b_start + 3 : col_b_start + 6] = t
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "planar_rotation",
                    "perp_axis": i,
                }
            )
    # --- Distance constraint: 1 scalar ---
    def _build_distance(self, joint: Joint) -> None:
        """Distance constraint: ``||p_b - p_a|| = d``.
        Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0``
        where ``direction = normalized(p_b - p_a)``.
        """
        p_a = joint.anchor_a
        p_b = joint.anchor_b
        diff = p_b - p_a
        dist = np.linalg.norm(diff)
        # Coincident anchors give no meaningful direction; fall back to
        # a fixed unit x to avoid division by ~zero.
        direction = np.array([1.0, 0.0, 0.0]) if dist < 1e-12 else diff / dist
        r_a = joint.anchor_a - self.bodies[joint.body_a].position
        r_b = joint.anchor_b - self.bodies[joint.body_b].position
        col_a_start, col_a_end = self._body_cols(joint.body_a)
        col_b_start, col_b_end = self._body_cols(joint.body_b)
        row = self._make_row()
        row[col_a_start : col_a_start + 3] = -direction
        row[col_a_start + 3 : col_a_end] = np.cross(r_a, direction)
        row[col_b_start : col_b_start + 3] = direction
        row[col_b_start + 3 : col_b_end] = -np.cross(r_b, direction)
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "distance",
            }
        )
    # --- Parallel constraint: 3 rotation constraints ---
    def _build_parallel(self, joint: Joint) -> None:
        """Parallel: all relative rotation constrained (same as fixed rotation).
        In practice only 2 of 3 are independent for a single axis, but
        we emit 3 and let the rank check sort it out.
        """
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        for axis_idx in range(3):
            row = self._make_row()
            row[col_a_start + 3 + axis_idx] = -1.0
            row[col_b_start + 3 + axis_idx] = 1.0
            self.jacobian_rows.append(row)
            self.row_labels.append(
                {
                    "joint_id": joint.joint_id,
                    "type": "parallel_rotation",
                    "axis": axis_idx,
                }
            )
    # --- Perpendicular constraint: 1 angular ---
    def _build_perpendicular(self, joint: Joint) -> None:
        """Perpendicular: single dot-product angular constraint."""
        # Relative angular velocity along joint.axis must vanish.
        axis = joint.axis / np.linalg.norm(joint.axis)
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        row = self._make_row()
        row[col_a_start + 3 : col_a_start + 6] = -axis
        row[col_b_start + 3 : col_b_start + 6] = axis
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "perpendicular",
            }
        )
    # --- Universal (Cardan) joint: 3 translation + 1 rotation ---
    def _build_universal(self, joint: Joint) -> None:
        """Universal joint: ball + one rotation constraint.
        Allows rotation about two axes, constrains rotation about the third.
        """
        self._build_ball(joint)
        # Lock relative angular velocity along joint.axis only.
        axis = joint.axis / np.linalg.norm(joint.axis)
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        row = self._make_row()
        row[col_a_start + 3 : col_a_start + 6] = -axis
        row[col_b_start + 3 : col_b_start + 6] = axis
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "universal_rotation",
            }
        )
    # --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled ---
    def _build_screw(self, joint: Joint) -> None:
        """Screw joint: coupled rotation-translation along axis.
        Like cylindrical but with a coupling constraint:
            ``v_axial - pitch * w_axial = 0``
        """
        self._build_cylindrical(joint)
        axis = joint.axis / np.linalg.norm(joint.axis)
        col_a_start, _ = self._body_cols(joint.body_a)
        col_b_start, _ = self._body_cols(joint.body_b)
        # Coupling row: axis . (v_b - v_a) - pitch * axis . (w_b - w_a) = 0.
        # Anchor offsets (r x w terms) are not included in this row.
        row = self._make_row()
        row[col_a_start : col_a_start + 3] = -axis
        row[col_b_start : col_b_start + 3] = axis
        row[col_a_start + 3 : col_a_start + 6] = joint.pitch * axis
        row[col_b_start + 3 : col_b_start + 6] = -joint.pitch * axis
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "screw_coupling",
            }
        )
    # --- Utilities ---
    def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Generate two unit vectors perpendicular to *axis* and each other."""
        # Cross with whichever world axis is least aligned with *axis*
        # so the cross product is well-conditioned.
        if abs(axis[0]) < 0.9:
            t1 = np.cross(axis, np.array([1.0, 0, 0]))
        else:
            t1 = np.cross(axis, np.array([0, 1.0, 0]))
        t1 /= np.linalg.norm(t1)
        t2 = np.cross(axis, t1)
        t2 /= np.linalg.norm(t2)
        return t1, t2
    def get_jacobian(self) -> np.ndarray:
        """Return the full constraint Jacobian matrix.

        Always shaped ``(n_rows, 6 * n_bodies)``, even when no rows exist.
        """
        if not self.jacobian_rows:
            return np.zeros((0, 6 * self.n_bodies))
        return np.array(self.jacobian_rows)
    def numerical_rank(self, tol: float = 1e-8) -> int:
        """Compute the numerical rank of the constraint Jacobian via SVD.
        This is the number of truly independent scalar constraints,
        accounting for geometric degeneracies that the combinatorial
        pebble game cannot detect.
        """
        j = self.get_jacobian()
        if j.size == 0:
            return 0
        # Rank = number of singular values above the tolerance.
        sv = np.linalg.svd(j, compute_uv=False)
        return int(np.sum(sv > tol))
    def find_dependencies(self, tol: float = 1e-8) -> list[int]:
        """Identify which constraint rows are numerically dependent.
        Returns indices of rows that can be removed without changing
        the Jacobian's rank.

        Greedy: rows are scanned in insertion order; a row is kept if it
        raises the rank of the rows kept so far, otherwise it is flagged
        as dependent. Which rows get flagged therefore depends on order.
        """
        j = self.get_jacobian()
        if j.size == 0:
            return []
        n_rows = j.shape[0]
        dependent: list[int] = []
        current = np.zeros((0, j.shape[1]))
        current_rank = 0
        for i in range(n_rows):
            # One SVD per candidate row — O(n_rows) SVDs total.
            candidate = np.vstack([current, j[i : i + 1, :]]) if current.size else j[i : i + 1, :]
            sv = np.linalg.svd(candidate, compute_uv=False)
            new_rank = int(np.sum(sv > tol))
            if new_rank > current_rank:
                current = candidate
                current_rank = new_rank
            else:
                dependent.append(i)
        return dependent

394
solver/datagen/labeling.py Normal file
View File

@@ -0,0 +1,394 @@
"""Ground truth labeling pipeline for synthetic assemblies.
Produces rich per-constraint, per-joint, per-body, and assembly-level
labels by running both the pebble game and Jacobian verification and
correlating their results.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
if TYPE_CHECKING:
from typing import Any
__all__ = ["AssemblyLabels", "label_assembly"]
# Sentinel body id for the virtual "world" body used when grounding.
_GROUND_ID = -1
# Singular values at or below this threshold are treated as zero in rank checks.
_SVD_TOL = 1e-8
# ---------------------------------------------------------------------------
# Label dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ConstraintLabel:
    """Per scalar-constraint label combining both analysis methods."""
    # Joint that produced this scalar constraint.
    joint_id: int
    # Index of the constraint within its joint (0 .. joint_type.dof - 1).
    constraint_idx: int
    # True if the pebble game accepted the corresponding edge.
    pebble_independent: bool
    # True if the Jacobian row is not numerically dependent on earlier rows.
    jacobian_independent: bool
@dataclass
class JointLabel:
    """Aggregated constraint counts for a single joint."""
    joint_id: int
    # Scalar constraints the pebble game accepted as independent.
    independent_count: int
    # Scalar constraints the pebble game rejected as redundant.
    redundant_count: int
    # Total scalar constraints emitted for this joint.
    total: int
@dataclass
class BodyDofLabel:
    """Per-body DOF signature from nullspace projection."""
    body_id: int
    # Free translational DOF (0-3) left to this body by the constraints.
    translational_dof: int
    # Free rotational DOF (0-3) left to this body by the constraints.
    rotational_dof: int
@dataclass
class AssemblyLabel:
    """Assembly-wide summary label."""
    # One of: "well-constrained", "underconstrained", "overconstrained", "mixed".
    classification: str
    # Internal (non-trivial) DOF from the Jacobian analysis, clamped >= 0.
    total_dof: int
    # Redundant scalar constraints per the pebble game.
    redundant_count: int
    is_rigid: bool
    is_minimally_rigid: bool
    # True when the numerical rank falls short of the combinatorial count.
    has_degeneracy: bool
@dataclass
class AssemblyLabels:
    """Complete ground truth labels for an assembly."""
    # Ordered per scalar constraint, same order the joints were processed.
    per_constraint: list[ConstraintLabel]
    per_joint: list[JointLabel]
    per_body: list[BodyDofLabel]
    assembly: AssemblyLabel
    # Full raw analysis, retained for backward compatibility.
    analysis: ConstraintAnalysis
    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict.

        Note: ``analysis`` is intentionally not serialized here; only the
        label summaries are included.
        """
        return {
            "per_constraint": [
                {
                    "joint_id": c.joint_id,
                    "constraint_idx": c.constraint_idx,
                    "pebble_independent": c.pebble_independent,
                    "jacobian_independent": c.jacobian_independent,
                }
                for c in self.per_constraint
            ],
            "per_joint": [
                {
                    "joint_id": j.joint_id,
                    "independent_count": j.independent_count,
                    "redundant_count": j.redundant_count,
                    "total": j.total,
                }
                for j in self.per_joint
            ],
            "per_body": [
                {
                    "body_id": b.body_id,
                    "translational_dof": b.translational_dof,
                    "rotational_dof": b.rotational_dof,
                }
                for b in self.per_body
            ],
            "assembly": {
                "classification": self.assembly.classification,
                "total_dof": self.assembly.total_dof,
                "redundant_count": self.assembly.redundant_count,
                "is_rigid": self.assembly.is_rigid,
                "is_minimally_rigid": self.assembly.is_minimally_rigid,
                "has_degeneracy": self.assembly.has_degeneracy,
            },
        }
# ---------------------------------------------------------------------------
# Per-body DOF from nullspace projection
# ---------------------------------------------------------------------------
def _compute_per_body_dof(
    j_reduced: np.ndarray,
    body_ids: list[int],
    ground_body: int | None,
    body_index: dict[int, int],
) -> list[BodyDofLabel]:
    """Compute translational and rotational DOF per body.

    Projects the reduced Jacobian's nullspace onto each body's 6-column
    block and takes the rank of the translational (first 3) and
    rotational (last 3) sub-blocks as that body's free DOF counts.

    ``body_index`` is accepted for interface compatibility; the reduced
    column layout is rebuilt locally because the ground body's columns
    have been removed from ``j_reduced``.
    """
    # Column-block index of each non-ground body inside the reduced J.
    reduced_idx = {
        bid: i for i, bid in enumerate(b for b in body_ids if b != ground_body)
    }

    def _zero(bid: int) -> BodyDofLabel:
        # Fully pinned body (also used for the ground body itself).
        return BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)

    if j_reduced.size == 0:
        # No constraint rows at all: every non-ground body keeps 3+3 DOF.
        return [
            _zero(bid)
            if bid == ground_body
            else BodyDofLabel(body_id=bid, translational_dof=3, rotational_dof=3)
            for bid in body_ids
        ]
    # Full SVD so Vh contains a complete orthonormal basis of column space.
    _left, singular, vh = np.linalg.svd(j_reduced, full_matrices=True)
    rank = int(np.sum(singular > _SVD_TOL))
    if rank >= j_reduced.shape[1]:
        # Constraints span the whole velocity space — nothing can move.
        return [_zero(bid) for bid in body_ids]
    # Rows of Vh beyond the rank span the nullspace (free motions).
    basis = vh[rank:]

    def _block_rank(block: np.ndarray) -> int:
        # Rank of a 3-column sub-block = independent motions in that category.
        if block.size == 0:
            return 0
        return int(np.sum(np.linalg.svd(block, compute_uv=False) > _SVD_TOL))

    labels: list[BodyDofLabel] = []
    for bid in body_ids:
        if bid == ground_body:
            labels.append(_zero(bid))
            continue
        start = reduced_idx[bid] * 6
        labels.append(
            BodyDofLabel(
                body_id=bid,
                translational_dof=_block_rank(basis[:, start : start + 3]),
                rotational_dof=_block_rank(basis[:, start + 3 : start + 6]),
            )
        )
    return labels
# ---------------------------------------------------------------------------
# Main labeling function
# ---------------------------------------------------------------------------
def label_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> AssemblyLabels:
    """Produce complete ground truth labels for an assembly.

    Runs both the pebble game and Jacobian verification internally,
    then correlates their results into per-constraint, per-joint,
    per-body, and assembly-level labels.

    Args:
        bodies: Rigid bodies in the assembly.
        joints: Joints connecting the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        AssemblyLabels with full label set and embedded ConstraintAnalysis.
    """
    # ---- Pebble Game ----
    pg = PebbleGame3D()
    all_edge_results: list[dict[str, Any]] = []
    if ground_body is not None:
        pg.add_body(_GROUND_ID)
    for body in bodies:
        pg.add_body(body.body_id)
    if ground_body is not None:
        # Ground the assembly by welding ground_body to a virtual world
        # body. Anchors are irrelevant to the pebble game; any point works.
        ground_joint = Joint(
            joint_id=-1,
            body_a=ground_body,
            body_b=_GROUND_ID,
            joint_type=JointType.FIXED,
            anchor_a=bodies[0].position if bodies else np.zeros(3),
            anchor_b=bodies[0].position if bodies else np.zeros(3),
        )
        pg.add_joint(ground_joint)
    for joint in joints:
        results = pg.add_joint(joint)
        all_edge_results.extend(results)
    grounded = ground_body is not None
    # BUGFIX: count independent edges from the *real* joints only.
    # len(pg.state.independent_edges) would also include the 6 edges of
    # the virtual ground joint, which have no Jacobian rows and would
    # inflate geometric_degeneracies for every grounded assembly.
    combinatorial_independent = sum(1 for res in all_edge_results if res["independent"])
    raw_dof = pg.get_dof()
    # When grounded, subtract the 6 residual pebbles held by the
    # body/world pair; otherwise subtract the 6 trivial motions below.
    ground_offset = 6 if grounded else 0
    effective_dof = raw_dof - ground_offset
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))
    redundant_count = pg.get_redundant_count()
    if redundant_count > 0 and effective_internal_dof > 0:
        classification = "mixed"
    elif redundant_count > 0:
        classification = "overconstrained"
    elif effective_internal_dof > 0:
        classification = "underconstrained"
    else:
        classification = "well-constrained"
    # ---- Jacobian Verification ----
    verifier = JacobianVerifier(bodies)
    for joint in joints:
        verifier.add_joint_constraints(joint)
    j_full = verifier.get_jacobian()
    j_reduced = j_full
    if ground_body is not None and j_full.shape[1] > 0:
        # Grounding = deleting the ground body's 6 velocity columns.
        # BUGFIX: gate on shape (not .size) so the deletion also runs when
        # there are zero constraint rows; previously a grounded assembly
        # with no joints fell back to 6 * len(bodies) columns and the
        # nullity overcounted the ground body's 6 DOF.
        idx = verifier.body_index[ground_body]
        cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
        j_reduced = np.delete(j_full, cols_to_remove, axis=1)
    if j_reduced.size > 0:
        sv = np.linalg.svd(j_reduced, compute_uv=False)
        jacobian_rank = int(np.sum(sv > _SVD_TOL))
    else:
        jacobian_rank = 0
    # get_jacobian always returns a (rows, 6 * n_bodies)-shaped array, so
    # shape[1] is valid even with zero rows.
    n_cols = j_reduced.shape[1]
    jacobian_nullity = n_cols - jacobian_rank
    # NOTE(review): dependencies are computed on the full (ungrounded)
    # Jacobian; rows independent there may become dependent after the
    # ground columns are removed — confirm whether is_minimally_rigid
    # should instead use j_reduced for grounded assemblies.
    dependent_rows = verifier.find_dependencies()
    dependent_set = set(dependent_rows)
    trivial_dof = 0 if grounded else 6
    jacobian_internal_dof = jacobian_nullity - trivial_dof
    # Degeneracies: constraints the pebble game deems independent but the
    # numerical rank shows are dependent (e.g. parallel revolute axes).
    geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
    is_rigid = jacobian_nullity <= trivial_dof
    is_minimally_rigid = is_rigid and len(dependent_rows) == 0
    # ---- Per-constraint labels ----
    # Jacobian rows are appended contiguously per joint, in joint order,
    # with joint_type.dof rows per joint — the same order the pebble game
    # produced all_edge_results, so row index == edge index.
    row_to_joint: list[tuple[int, int]] = []
    for joint in joints:
        dof = joint.joint_type.dof
        for ci in range(dof):
            row_to_joint.append((joint.joint_id, ci))
    per_constraint: list[ConstraintLabel] = []
    for edge_idx, edge_result in enumerate(all_edge_results):
        jid = edge_result["joint_id"]
        ci = edge_result["constraint_index"]
        pebble_indep = edge_result["independent"]
        # Defensive guard: both lists are sums of joint_type.dof, so the
        # lengths should always match.
        jacobian_indep = True
        if edge_idx < len(row_to_joint):
            jacobian_indep = edge_idx not in dependent_set
        per_constraint.append(
            ConstraintLabel(
                joint_id=jid,
                constraint_idx=ci,
                pebble_independent=pebble_indep,
                jacobian_independent=jacobian_indep,
            )
        )
    # ---- Per-joint labels ----
    # Aggregate pebble-game independence counts per joint.
    joint_agg: dict[int, JointLabel] = {}
    for cl in per_constraint:
        if cl.joint_id not in joint_agg:
            joint_agg[cl.joint_id] = JointLabel(
                joint_id=cl.joint_id,
                independent_count=0,
                redundant_count=0,
                total=0,
            )
        jl = joint_agg[cl.joint_id]
        jl.total += 1
        if cl.pebble_independent:
            jl.independent_count += 1
        else:
            jl.redundant_count += 1
    per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg]
    # ---- Per-body DOF labels ----
    body_ids = [b.body_id for b in bodies]
    per_body = _compute_per_body_dof(
        j_reduced,
        body_ids,
        ground_body,
        verifier.body_index,
    )
    # ---- Assembly label ----
    assembly_label = AssemblyLabel(
        classification=classification,
        total_dof=max(0, jacobian_internal_dof),
        redundant_count=redundant_count,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
        has_degeneracy=geometric_degeneracies > 0,
    )
    # ---- ConstraintAnalysis (for backward compat) ----
    analysis = ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=redundant_count,
        combinatorial_classification=classification,
        per_edge_results=all_edge_results,
        jacobian_rank=jacobian_rank,
        jacobian_nullity=jacobian_nullity,
        jacobian_internal_dof=max(0, jacobian_internal_dof),
        numerically_dependent=dependent_rows,
        geometric_degeneracies=geometric_degeneracies,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
    )
    return AssemblyLabels(
        per_constraint=per_constraint,
        per_joint=per_joint,
        per_body=per_body,
        assembly=assembly_label,
        analysis=analysis,
    )

View File

@@ -0,0 +1,258 @@
"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis.
Implements the pebble game algorithm adapted for CAD assembly constraint
graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints
between bodies remove DOF according to their type.
The pebble game provides a fast combinatorial *necessary* condition for
rigidity via Tay's theorem. It does not detect geometric degeneracies —
use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient*
condition.
References:
- Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008
- Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity
Percolation: The Pebble Game", J. Comput. Phys., 1997
- Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from solver.datagen.types import Joint, PebbleState
if TYPE_CHECKING:
from typing import Any
__all__ = ["PebbleGame3D"]
class PebbleGame3D:
    """Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks.
    For body-bar-hinge structures in 3D, Tay's theorem states that a
    multigraph G on n vertices is generically minimally rigid iff:
        |E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2)
    The (6,6)-pebble game tests this sparsity condition incrementally.
    Each vertex starts with 6 pebbles (representing 6 DOF). To insert
    an edge, we need to collect 6+1=7 pebbles on its two endpoints.
    If we can, the edge is independent (removes a DOF). If not, it's
    redundant (overconstrained).
    In the CAD assembly context:
    - Vertices = rigid bodies
    - Edges = scalar constraints from joints
    - A revolute joint (5 DOF removed) maps to 5 multigraph edges
    - A fixed joint (6 DOF removed) maps to 6 multigraph edges
    """
    K = 6 # Pebbles per vertex (DOF per rigid body in 3D)
    L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge
    def __init__(self) -> None:
        # All mutable game state (pebble counts, edge directions,
        # adjacency, independent/redundant edge sets) lives in PebbleState.
        self.state = PebbleState()
        # Monotonically increasing id assigned to each inserted edge.
        self._edge_counter = 0
        self._bodies: set[int] = set()
    def add_body(self, body_id: int) -> None:
        """Register a rigid body (vertex) with K=6 free pebbles."""
        # Idempotent: re-adding an existing body is a no-op.
        if body_id in self._bodies:
            return
        self._bodies.add(body_id)
        self.state.free_pebbles[body_id] = self.K
        self.state.incoming[body_id] = set()
        self.state.outgoing[body_id] = set()
    def add_joint(self, joint: Joint) -> list[dict[str, Any]]:
        """Expand a joint into multigraph edges and test each for independence.
        A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph.
        Each edge is tested individually via the pebble game.
        Returns a list of dicts, one per scalar constraint, with:
        - edge_id: int
        - independent: bool
        - dof_remaining: int (total free pebbles after this edge)
        """
        # Endpoints are registered lazily; add_body is idempotent.
        self.add_body(joint.body_a)
        self.add_body(joint.body_b)
        num_constraints = joint.joint_type.dof
        results: list[dict[str, Any]] = []
        for i in range(num_constraints):
            edge_id = self._edge_counter
            self._edge_counter += 1
            independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b)
            # Snapshot of total remaining DOF after this insertion.
            total_free = sum(self.state.free_pebbles.values())
            results.append(
                {
                    "edge_id": edge_id,
                    "joint_id": joint.joint_id,
                    "constraint_index": i,
                    "independent": independent,
                    "dof_remaining": total_free,
                }
            )
        return results
    def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool:
        """Try to insert a directed edge between u and v.
        The edge is accepted (independent) iff we can collect L+1 = 7
        pebbles on the two endpoints {u, v} combined.
        If accepted, one pebble is consumed and the edge is directed
        away from the vertex that gives up the pebble.
        """
        # Count current free pebbles on u and v
        available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
        # Try to gather enough pebbles via DFS reachability search
        if available < self.L + 1:
            needed = (self.L + 1) - available
            # Try to free pebbles by searching from u first, then v
            # (u and v themselves are forbidden as pebble sources).
            for target in (u, v):
                while needed > 0:
                    found = self._search_and_collect(target, frozenset({u, v}))
                    if not found:
                        break
                    needed -= 1
            # Recheck after collection attempts
            available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
        if available >= self.L + 1:
            # Accept: consume a pebble from whichever endpoint has one
            source = u if self.state.free_pebbles[u] > 0 else v
            self.state.free_pebbles[source] -= 1
            self.state.directed_edges[edge_id] = (source, v if source == u else u)
            self.state.outgoing[source].add((edge_id, v if source == u else u))
            target = v if source == u else u
            self.state.incoming[target].add((edge_id, source))
            self.state.independent_edges.add(edge_id)
            return True
        else:
            # Reject: edge is redundant (overconstrained)
            self.state.redundant_edges.add(edge_id)
            return False
    def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
        """DFS to find a free pebble reachable from *target* and move it.
        Follows directed edges *backwards* (from destination to source)
        to find a vertex with a free pebble that isn't in *forbidden*.
        When found, reverses the path to move the pebble to *target*.
        Returns True if a pebble was successfully moved to target.
        """
        # BFS/DFS through the directed graph following outgoing edges
        # from target. An outgoing edge (target -> w) means target spent
        # a pebble on that edge. If we can find a vertex with a free
        # pebble, we reverse edges along the path to move it.
        visited: set[int] = set()
        # Stack: (current_vertex, path_of_edge_ids_to_reverse)
        stack: list[tuple[int, list[int]]] = [(target, [])]
        while stack:
            current, path = stack.pop()
            if current in visited:
                continue
            visited.add(current)
            # Check if current vertex (not in forbidden, not target)
            # has a free pebble
            if (
                current != target
                and current not in forbidden
                and self.state.free_pebbles[current] > 0
            ):
                # Found a pebble — reverse the path
                self._reverse_path(path, current)
                return True
            # Follow outgoing edges from current vertex
            for eid, neighbor in self.state.outgoing.get(current, set()):
                if neighbor not in visited:
                    # Copy the path so sibling branches stay independent.
                    stack.append((neighbor, [*path, eid]))
        return False
    def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
        """Reverse directed edges along a path, moving a pebble to the start.
        The pebble at *pebble_source* is consumed by the last edge in
        the path, and a pebble is freed at the path's start vertex.
        Intermediate vertices gain and lose one pebble each, netting zero,
        so only the two endpoints' counts are updated below.
        """
        if not edge_ids:
            return
        # Reverse each edge in the path
        for eid in edge_ids:
            old_source, old_target = self.state.directed_edges[eid]
            # Remove from adjacency
            self.state.outgoing[old_source].discard((eid, old_target))
            self.state.incoming[old_target].discard((eid, old_source))
            # Reverse direction
            self.state.directed_edges[eid] = (old_target, old_source)
            self.state.outgoing[old_target].add((eid, old_source))
            self.state.incoming[old_source].add((eid, old_target))
        # Move pebble counts: source loses one, first vertex in path gains one
        self.state.free_pebbles[pebble_source] -= 1
        # After all reversals, the vertex at the beginning of the
        # search path gains a pebble
        _first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
        self.state.free_pebbles[first_tgt] += 1
    def get_dof(self) -> int:
        """Total remaining DOF = sum of free pebbles.
        For a fully rigid assembly, this should be 6 (the trivial rigid
        body motions of the whole assembly). Internal DOF = total - 6.
        """
        return sum(self.state.free_pebbles.values())
    def get_internal_dof(self) -> int:
        """Internal (non-trivial) degrees of freedom."""
        return max(0, self.get_dof() - 6)
    def is_rigid(self) -> bool:
        """Combinatorial rigidity check: rigid iff at most 6 pebbles remain."""
        return self.get_dof() <= self.L
    def get_redundant_count(self) -> int:
        """Number of redundant (overconstrained) scalar constraints."""
        return len(self.state.redundant_edges)
    def classify_assembly(self, *, grounded: bool = False) -> str:
        """Classify the assembly state.
        Args:
            grounded: If True, the baseline trivial DOF is 0 (not 6),
                because the ground body's 6 DOF were removed.
        """
        total_dof = self.get_dof()
        redundant = self.get_redundant_count()
        baseline = 0 if grounded else self.L
        if redundant > 0 and total_dof > baseline:
            return "mixed" # Both under and over-constrained regions
        elif redundant > 0:
            return "overconstrained"
        elif total_dof > baseline:
            return "underconstrained"
        elif total_dof == baseline:
            return "well-constrained"
        else:
            # total_dof < baseline; defensive fallback (the (6,6) game
            # should not drive an ungrounded assembly below 6 pebbles).
            return "overconstrained"

144
solver/datagen/types.py Normal file
View File

@@ -0,0 +1,144 @@
"""Shared data types for assembly constraint analysis.
Types ported from the pebble-game synthetic data generator for reuse
across the solver package (data generation, training, inference).
"""
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"ConstraintAnalysis",
"Joint",
"JointType",
"PebbleState",
"RigidBody",
]
# ---------------------------------------------------------------------------
# Joint definitions: each joint type removes a known number of DOF
# ---------------------------------------------------------------------------
class JointType(enum.Enum):
    """Standard CAD joint types with their DOF-removal counts.

    Each joint between two 6-DOF rigid bodies removes a fixed number of
    relative degrees of freedom. In the body-bar-hinge multigraph
    representation a joint maps to as many edges as DOF it removes.

    Members are ``(ordinal, dof_removed)`` pairs so that joint types
    sharing a DOF count remain distinct enum members; read the
    :attr:`dof` property for the scalar constraint count.
    """

    FIXED = (0, 6)          # locks all relative motion
    REVOLUTE = (1, 5)       # rotation about one axis
    CYLINDRICAL = (2, 4)    # rotation + translation along one axis
    SLIDER = (3, 5)         # translation along one axis (prismatic)
    BALL = (4, 3)           # rotation about a point (spherical)
    PLANAR = (5, 3)         # 2D translation + rotation normal to plane
    SCREW = (6, 5)          # coupled rotation-translation (helical)
    UNIVERSAL = (7, 4)      # two rotational DOF (Cardan/U-joint)
    PARALLEL = (8, 3)       # forces parallel orientation (3 rotation constraints)
    PERPENDICULAR = (9, 1)  # single angular constraint
    DISTANCE = (10, 1)      # single scalar distance constraint

    @property
    def dof(self) -> int:
        """Number of scalar constraints (DOF removed) by this joint type."""
        _ordinal, dof_removed = self.value
        return dof_removed
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class RigidBody:
    """A rigid body in the assembly with pose and geometry info."""

    # Unique identifier of the body within its assembly.
    body_id: int
    # World-frame translation; defaults to the origin (3-vector).
    position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # World-frame rotation; defaults to the 3x3 identity matrix.
    orientation: np.ndarray = field(default_factory=lambda: np.eye(3))
    # Anchor points for joints, in local frame
    # Populated when joints reference specific geometry
    local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
@dataclass
class Joint:
    """A joint connecting two rigid bodies."""

    # Unique identifier of the joint within its assembly.
    joint_id: int
    body_a: int  # Index of first body
    body_b: int  # Index of second body
    # Kind of joint; determines how many scalar constraints it imposes.
    joint_type: JointType
    # Joint parameters in world frame
    anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
    anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # Joint axis in world frame; defaults to +z.
    axis: np.ndarray = field(
        default_factory=lambda: np.array([0.0, 0.0, 1.0]),
    )
    # For screw joints
    pitch: float = 0.0
@dataclass
class PebbleState:
    """Tracks the state of the pebble game on the multigraph."""

    # Number of free pebbles per body (vertex). Starts at 6.
    free_pebbles: dict[int, int] = field(default_factory=dict)
    # Directed edges: edge_id -> (source_body, target_body)
    # Edge is directed away from the body that "spent" a pebble.
    directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)
    # Track which edges are independent vs redundant
    independent_edges: set[int] = field(default_factory=set)
    redundant_edges: set[int] = field(default_factory=set)
    # Adjacency: body_id -> set of (edge_id, neighbor_body_id)
    # Following directed edges *towards* a body (incoming edges)
    incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
    # Outgoing edges from a body
    outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
@dataclass
class ConstraintAnalysis:
    """Results of analyzing an assembly's constraint system."""

    # Pebble game (combinatorial) results
    combinatorial_dof: int
    combinatorial_internal_dof: int
    combinatorial_redundant: int
    # One of "underconstrained" / "well-constrained" /
    # "overconstrained" / "mixed".
    combinatorial_classification: str
    # One entry per scalar constraint edge processed by the pebble game.
    per_edge_results: list[dict[str, Any]]
    # Numerical (Jacobian) results
    jacobian_rank: int
    jacobian_nullity: int  # = 6n - rank = total DOF
    jacobian_internal_dof: int  # = nullity - 6
    # Indices of constraints found numerically dependent.
    numerically_dependent: list[int]
    # Combined
    geometric_degeneracies: int  # = combinatorial_independent - jacobian_rank
    is_rigid: bool
    is_minimally_rigid: bool

View File

View File

View File

View File

View File

0
tests/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,240 @@
"""Tests for solver.datagen.analysis -- combined analysis function."""
from __future__ import annotations
import numpy as np
import pytest
from solver.datagen.analysis import analyze_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _two_bodies() -> list[RigidBody]:
    """Two bodies spaced 2 units apart along the x-axis."""
    positions = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(p)) for i, p in enumerate(positions)]
def _triangle_bodies() -> list[RigidBody]:
    """Three bodies arranged in a (roughly equilateral) triangle."""
    positions = (
        [0.0, 0.0, 0.0],
        [2.0, 0.0, 0.0],
        [1.0, 1.7, 0.0],
    )
    return [RigidBody(i, position=np.array(p)) for i, p in enumerate(positions)]
# ---------------------------------------------------------------------------
# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF)
# ---------------------------------------------------------------------------
class TestTwoBodiesRevolute:
    """Demo scenario 1: two bodies connected by a revolute joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # Revolute removes 5 of 6 relative DOF, leaving one rotation.
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        # The single remaining rotation about the shared z-axis.
        assert result.jacobian_internal_dof == 1

    def test_not_rigid(self, result: ConstraintAnalysis) -> None:
        assert not result.is_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification == "underconstrained"

    def test_no_redundant(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_redundant == 0
# ---------------------------------------------------------------------------
# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF)
# ---------------------------------------------------------------------------
class TestTwoBodiesFixed:
    """Demo scenario 2: two bodies connected by a fixed joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # A fixed joint removes all 6 relative DOF.
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.FIXED,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        assert result.jacobian_internal_dof == 0

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
        # Exactly the constraints needed, none redundant.
        assert result.is_minimally_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification == "well-constrained"
# ---------------------------------------------------------------------------
# Scenario 3: Triangle with revolute joints (overconstrained)
# ---------------------------------------------------------------------------
class TestTriangleRevolute:
    """Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # Closing the loop with a third revolute makes constraints redundant.
        bodies = _triangle_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.5, 0.85, 0.0]),
                anchor_b=np.array([1.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                2,
                body_a=2,
                body_b=0,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([0.5, 0.85, 0.0]),
                anchor_b=np.array([0.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_has_redundant(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_redundant > 0

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification in ("overconstrained", "mixed")

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
        # The Jacobian analysis should also flag dependent constraints.
        assert len(result.numerically_dependent) > 0
# ---------------------------------------------------------------------------
# Scenario 4: Parallel revolute axes (geometric degeneracy)
# ---------------------------------------------------------------------------
class TestParallelRevoluteAxes:
    """Demo scenario 4: parallel revolute axes create geometric degeneracies."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # Three collinear bodies joined by two revolutes sharing the z-axis.
        bodies = [
            RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
            RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
            RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
        ]
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([3.0, 0.0, 0.0]),
                anchor_b=np.array([3.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None:
        """Parallel axes produce at least one geometric degeneracy."""
        assert result.geometric_degeneracies > 0
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestNoJoints:
    """Assembly with bodies but no joints."""

    def test_all_dof_free(self) -> None:
        bodies = _two_bodies()
        result = analyze_assembly(bodies, [], ground_body=0)
        # Body 1 is completely free (6 DOF), body 0 is grounded
        assert result.jacobian_internal_dof > 0
        assert not result.is_rigid

    def test_ungrounded(self) -> None:
        # No ground body: every body keeps its trivial DOF.
        bodies = _two_bodies()
        result = analyze_assembly(bodies, [])
        assert result.combinatorial_classification == "underconstrained"
class TestReturnType:
    """Verify the return object is a proper ConstraintAnalysis."""

    def test_instance(self) -> None:
        bodies = _two_bodies()
        joints = [Joint(0, 0, 1, JointType.FIXED)]
        result = analyze_assembly(bodies, joints)
        assert isinstance(result, ConstraintAnalysis)

    def test_per_edge_results_populated(self) -> None:
        bodies = _two_bodies()
        joints = [Joint(0, 0, 1, JointType.REVOLUTE)]
        result = analyze_assembly(bodies, joints)
        # A revolute joint contributes 5 scalar constraint edges.
        assert len(result.per_edge_results) == 5

View File

@@ -0,0 +1,337 @@
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
from __future__ import annotations
import json
import os
import subprocess
import sys
from typing import TYPE_CHECKING
from solver.datagen.dataset import (
DatasetConfig,
DatasetGenerator,
_derive_shard_seed,
_parse_scalar,
parse_simple_yaml,
)
if TYPE_CHECKING:
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# DatasetConfig
# ---------------------------------------------------------------------------
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        cfg = DatasetConfig()
        assert cfg.num_assemblies == 100_000
        assert cfg.seed == 42
        assert cfg.shard_size == 1000
        assert cfg.num_workers == 4

    def test_from_dict_flat(self) -> None:
        d: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.num_assemblies == 500
        assert cfg.seed == 123

    def test_from_dict_nested_body_count(self) -> None:
        # Nested {"min": ..., "max": ...} is flattened onto *_min/*_max.
        d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.body_count_min == 3
        assert cfg.body_count_max == 20

    def test_from_dict_flat_body_count(self) -> None:
        d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.body_count_min == 5
        assert cfg.body_count_max == 30

    def test_from_dict_complexity_distribution(self) -> None:
        d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        d: dict[str, Any] = {"templates": ["chain", "tree"]}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.templates == ["chain", "tree"]
# ---------------------------------------------------------------------------
# Minimal YAML parser
# ---------------------------------------------------------------------------
class TestParseScalar:
    """_parse_scalar handles different value types."""

    def test_int(self) -> None:
        assert _parse_scalar("42") == 42

    def test_float(self) -> None:
        assert _parse_scalar("3.14") == 3.14

    def test_bool_true(self) -> None:
        assert _parse_scalar("true") is True

    def test_bool_false(self) -> None:
        assert _parse_scalar("false") is False

    def test_string(self) -> None:
        # Anything that is not int/float/bool stays a plain string.
        assert _parse_scalar("hello") == "hello"

    def test_inline_comment(self) -> None:
        # Trailing "# ..." comments must be stripped before parsing.
        assert _parse_scalar("0.4 # some comment") == 0.4
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["name"] == "test"
        assert result["num"] == 42
        assert result["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("body_count:\n  min: 2\n  max: 50\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("templates:\n  - chain\n  - tree\n  - loop\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("dist:\n  simple: 0.4  # comment\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config."""
        # NOTE(review): relative path assumes the test runs from the
        # repository root — confirm against the CI working directory.
        result = parse_simple_yaml("configs/dataset/synthetic.yaml")
        assert result["name"] == "synthetic"
        assert result["num_assemblies"] == 100000
        assert isinstance(result["complexity_distribution"], dict)
        assert isinstance(result["templates"], list)
        assert result["shard_size"] == 1000
# ---------------------------------------------------------------------------
# Shard seed derivation
# ---------------------------------------------------------------------------
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        # Same (global seed, shard index) -> same derived seed.
        s1 = _derive_shard_seed(42, 0)
        s2 = _derive_shard_seed(42, 0)
        assert s1 == s2

    def test_different_shards(self) -> None:
        s1 = _derive_shard_seed(42, 0)
        s2 = _derive_shard_seed(42, 1)
        assert s1 != s2

    def test_different_global_seeds(self) -> None:
        s1 = _derive_shard_seed(42, 0)
        s2 = _derive_shard_seed(99, 0)
        assert s1 != s2
# ---------------------------------------------------------------------------
# DatasetGenerator — small end-to-end tests
# ---------------------------------------------------------------------------
class TestDatasetGenerator:
    """End-to-end tests with small datasets."""

    def test_small_generation(self, tmp_path: Path) -> None:
        """Generate 10 examples in a single shard."""
        cfg = DatasetConfig(
            num_assemblies=10,
            output_dir=str(tmp_path / "output"),
            shard_size=10,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        # One shard file plus index.json and stats.json are expected.
        shards_dir = tmp_path / "output" / "shards"
        assert shards_dir.exists()
        shard_files = sorted(shards_dir.glob("shard_*"))
        assert len(shard_files) == 1
        index_file = tmp_path / "output" / "index.json"
        assert index_file.exists()
        index = json.loads(index_file.read_text())
        assert index["total_assemblies"] == 10
        stats_file = tmp_path / "output" / "stats.json"
        assert stats_file.exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """Generate 20 examples across 2 shards."""
        cfg = DatasetConfig(
            num_assemblies=20,
            output_dir=str(tmp_path / "output"),
            shard_size=10,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        shards_dir = tmp_path / "output" / "shards"
        shard_files = sorted(shards_dir.glob("shard_*"))
        assert len(shard_files) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """Resume skips already-completed shards."""
        cfg = DatasetConfig(
            num_assemblies=20,
            output_dir=str(tmp_path / "output"),
            shard_size=10,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        # Record shard modification times
        shards_dir = tmp_path / "output" / "shards"
        mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}
        # Remove stats (simulate incomplete) and re-run
        (tmp_path / "output" / "stats.json").unlink()
        DatasetGenerator(cfg).run()
        # Shards should NOT have been regenerated
        for p in shards_dir.glob("shard_*"):
            assert p.stat().st_mtime == mtimes[p.name]
        # Stats should be regenerated
        assert (tmp_path / "output" / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """Checkpoint file is cleaned up after completion."""
        cfg = DatasetConfig(
            num_assemblies=5,
            output_dir=str(tmp_path / "output"),
            shard_size=5,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        checkpoint = tmp_path / "output" / ".checkpoint.json"
        assert not checkpoint.exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json has expected top-level keys."""
        cfg = DatasetConfig(
            num_assemblies=10,
            output_dir=str(tmp_path / "output"),
            shard_size=10,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        stats = json.loads((tmp_path / "output" / "stats.json").read_text())
        assert stats["total_examples"] == 10
        assert "classification_distribution" in stats
        assert "body_count_histogram" in stats
        assert "joint_type_distribution" in stats
        assert "dof_statistics" in stats
        assert "geometric_degeneracy" in stats
        assert "rigidity" in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json has expected format."""
        cfg = DatasetConfig(
            num_assemblies=15,
            output_dir=str(tmp_path / "output"),
            shard_size=10,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()
        index = json.loads((tmp_path / "output" / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        # 15 assemblies at shard_size 10 -> shards of 10 and 5.
        assert index["total_shards"] == 2
        assert "shards" in index
        for _name, info in index["shards"].items():
            assert "start_id" in info
            assert "count" in info

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Same seed produces same results."""
        for run_dir in ("run1", "run2"):
            cfg = DatasetConfig(
                num_assemblies=5,
                output_dir=str(tmp_path / run_dir),
                shard_size=5,
                seed=42,
                num_workers=1,
            )
            DatasetGenerator(cfg).run()
        s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
        s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert s1["total_examples"] == s2["total_examples"]
        assert s1["classification_distribution"] == s2["classification_distribution"]
# ---------------------------------------------------------------------------
# CLI integration test
# ---------------------------------------------------------------------------
class TestCLI:
    """Run the script via subprocess."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        # NOTE(review): the hard-coded "/home/developer" cwd/PYTHONPATH
        # ties this test to one machine layout — confirm it matches CI.
        result = subprocess.run(
            [
                sys.executable,
                "scripts/generate_synthetic.py",
                "--num-assemblies",
                "5",
                "--output-dir",
                str(tmp_path / "cli_out"),
                "--shard-size",
                "5",
                "--num-workers",
                "1",
                "--seed",
                "42",
            ],
            capture_output=True,
            text=True,
            cwd="/home/developer",
            timeout=120,
            env={**os.environ, "PYTHONPATH": "/home/developer"},
        )
        assert result.returncode == 0, (
            f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )
        assert (tmp_path / "cli_out" / "index.json").exists()
        assert (tmp_path / "cli_out" / "stats.json").exists()

View File

@@ -0,0 +1,682 @@
"""Tests for solver.datagen.generator -- synthetic assembly generation."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.types import JointType
# ---------------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ---------------------------------------------------------------------------
class TestChainAssembly:
    """generate_chain_assembly produces valid underconstrained chains."""

    def test_returns_three_tuple(self) -> None:
        # n bodies in a chain are linked by n - 1 joints.
        gen = SyntheticAssemblyGenerator(seed=0)
        bodies, joints, _analysis = gen.generate_chain_assembly(4)
        assert len(bodies) == 4
        assert len(joints) == 3

    def test_chain_underconstrained(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        _, _, analysis = gen.generate_chain_assembly(4)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_chain_body_ids(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        bodies, _, _ = gen.generate_chain_assembly(5)
        ids = [b.body_id for b in bodies]
        assert ids == [0, 1, 2, 3, 4]

    def test_chain_joint_connectivity(self) -> None:
        # Joint i connects consecutive bodies i and i + 1.
        gen = SyntheticAssemblyGenerator(seed=0)
        _, joints, _ = gen.generate_chain_assembly(4)
        for i, j in enumerate(joints):
            assert j.body_a == i
            assert j.body_b == i + 1

    def test_chain_custom_joint_type(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        _, joints, _ = gen.generate_chain_assembly(
            3,
            joint_type=JointType.BALL,
        )
        assert all(j.joint_type is JointType.BALL for j in joints)
class TestRigidAssembly:
    """generate_rigid_assembly produces rigid assemblies."""

    def test_rigid(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_rigid_assembly(4)
        assert analysis.is_rigid

    def test_spanning_tree_structure(self) -> None:
        """n bodies should have at least n-1 joints (spanning tree)."""
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_rigid_assembly(5)
        assert len(joints) >= len(bodies) - 1

    def test_deterministic(self) -> None:
        """Same seed produces same results."""
        g1 = SyntheticAssemblyGenerator(seed=99)
        g2 = SyntheticAssemblyGenerator(seed=99)
        _, j1, a1 = g1.generate_rigid_assembly(4)
        _, j2, a2 = g2.generate_rigid_assembly(4)
        assert a1.jacobian_rank == a2.jacobian_rank
        assert len(j1) == len(j2)
class TestOverconstrainedAssembly:
    """generate_overconstrained_assembly adds redundant constraints."""

    def test_has_redundant(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_overconstrained_assembly(
            4,
            extra_joints=2,
        )
        assert analysis.combinatorial_redundant > 0

    def test_extra_joints_added(self) -> None:
        # Same seed for both generators so the base assemblies match.
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints_base, _ = gen.generate_rigid_assembly(4)
        gen2 = SyntheticAssemblyGenerator(seed=42)
        _, joints_over, _ = gen2.generate_overconstrained_assembly(
            4,
            extra_joints=3,
        )
        # Overconstrained has base joints + extra
        assert len(joints_over) > len(joints_base)
# ---------------------------------------------------------------------------
# New topology generators
# ---------------------------------------------------------------------------
class TestTreeAssembly:
    """generate_tree_assembly produces tree-structured assemblies."""

    def test_body_and_joint_counts(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_tree_assembly(6)
        assert len(bodies) == 6
        assert len(joints) == 5  # n - 1

    def test_underconstrained(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_tree_assembly(6)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        # Branching factor changes shape, not body/joint counts.
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_tree_assembly(
            10,
            branching_factor=2,
        )
        assert len(bodies) == 10
        assert len(joints) == 9

    def test_mixed_joint_types(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        _, joints, _ = gen.generate_tree_assembly(10, joint_types=types)
        used = {j.joint_type for j in joints}
        # With 9 joints and 3 types, very likely to use at least 2
        assert len(used) >= 2

    def test_single_joint_type(self) -> None:
        # joint_types also accepts a single JointType, not just a list.
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_tree_assembly(
            5,
            joint_types=JointType.BALL,
        )
        assert all(j.joint_type is JointType.BALL for j in joints)

    def test_sequential_body_ids(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = gen.generate_tree_assembly(7)
        assert [b.body_id for b in bodies] == list(range(7))
class TestLoopAssembly:
    """generate_loop_assembly produces closed-loop assemblies."""

    def test_body_and_joint_counts(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_loop_assembly(5)
        assert len(bodies) == 5
        assert len(joints) == 5  # n joints for n bodies

    def test_has_redundancy(self) -> None:
        # Closing the loop necessarily introduces redundant constraints.
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_loop_assembly(5)
        assert analysis.combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_loop_assembly(4)
        edges = {(j.body_a, j.body_b) for j in joints}
        assert (0, 1) in edges
        assert (1, 2) in edges
        assert (2, 3) in edges
        assert (3, 0) in edges  # wrap-around

    def test_minimum_bodies_error(self) -> None:
        # A loop needs at least 3 bodies to close.
        gen = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 3"):
            gen.generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        types = [JointType.REVOLUTE, JointType.FIXED]
        _, joints, _ = gen.generate_loop_assembly(8, joint_types=types)
        used = {j.joint_type for j in joints}
        assert len(used) >= 2
class TestStarAssembly:
    """generate_star_assembly produces hub-and-spoke assemblies."""

    def test_body_and_joint_counts(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_star_assembly(6)
        assert len(bodies) == 6
        assert len(joints) == 5  # n - 1

    def test_all_joints_connect_to_hub(self) -> None:
        # Body 0 is the hub; every joint must touch it.
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_star_assembly(6)
        for j in joints:
            assert j.body_a == 0 or j.body_b == 0

    def test_underconstrained(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_star_assembly(5)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 2"):
            gen.generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = gen.generate_star_assembly(4)
        np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
class TestMixedAssembly:
    """generate_mixed_assembly produces tree+loop hybrid assemblies."""

    def test_more_joints_than_tree(self) -> None:
        # Nonzero edge density adds loop-closing joints on top of the tree.
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = gen.generate_mixed_assembly(
            8,
            edge_density=0.3,
        )
        assert len(joints) > len(bodies) - 1

    def test_density_zero_is_tree(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _bodies, joints, _ = gen.generate_mixed_assembly(
            5,
            edge_density=0.0,
        )
        assert len(joints) == 4  # spanning tree only

    def test_density_validation(self) -> None:
        # edge_density is restricted to [0, 1].
        gen = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="must be in"):
            gen.generate_mixed_assembly(5, edge_density=1.5)
        with pytest.raises(ValueError, match="must be in"):
            gen.generate_mixed_assembly(5, edge_density=-0.1)

    def test_no_duplicate_edges(self) -> None:
        # frozenset ignores direction, catching (a, b) vs (b, a) dupes.
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5)
        edges = [frozenset([j.body_a, j.body_b]) for j in joints]
        assert len(edges) == len(set(edges))

    def test_high_density(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _bodies, joints, _ = gen.generate_mixed_assembly(
            5,
            edge_density=1.0,
        )
        # Fully connected: 5*(5-1)/2 = 10 edges
        assert len(joints) == 10
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
class TestAxisStrategy:
    """Axis sampling strategies produce valid unit vectors."""

    def test_cardinal_axis_from_six(self) -> None:
        # Sampling 200 times should cover all six signed cardinal axes.
        gen = SyntheticAssemblyGenerator(seed=0)
        axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
        expected = {
            (1, 0, 0),
            (-1, 0, 0),
            (0, 1, 0),
            (0, -1, 0),
            (0, 0, 1),
            (0, 0, -1),
        }
        assert axes == expected

    def test_random_axis_unit_norm(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        for _ in range(50):
            axis = gen._sample_axis("random")
            assert abs(np.linalg.norm(axis) - 1.0) < 1e-10

    def test_near_parallel_close_to_base(self) -> None:
        # Dot product > 0.95 means within ~18 degrees of the base axis.
        gen = SyntheticAssemblyGenerator(seed=0)
        base = np.array([0.0, 0.0, 1.0])
        for _ in range(50):
            axis = gen._near_parallel_axis(base)
            assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
            assert np.dot(axis, base) > 0.95

    def test_sample_axis_cardinal(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        axis = gen._sample_axis("cardinal")
        cardinals = [
            np.array(v, dtype=float)
            for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
        ]
        assert any(np.allclose(axis, c) for c in cardinals)

    def test_sample_axis_near_parallel(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=0)
        axis = gen._sample_axis("near_parallel")
        z = np.array([0.0, 0.0, 1.0])
        assert np.dot(axis, z) > 0.95
# ---------------------------------------------------------------------------
# Geometric diversity: orientations
# ---------------------------------------------------------------------------
class TestRandomOrientations:
    """Bodies should have non-identity orientations."""

    def test_bodies_have_orientations(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = gen.generate_tree_assembly(5)
        non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
        assert non_identity >= 3

    def test_orientations_are_valid_rotations(self) -> None:
        # Proper rotation matrix: orthonormal with determinant +1.
        gen = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = gen.generate_star_assembly(6)
        for b in bodies:
            r = b.orientation
            # R^T R == I
            np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
            # det(R) == 1
            assert abs(np.linalg.det(r) - 1.0) < 1e-10

    def test_all_generators_set_orientations(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        # Chain
        bodies, _, _ = gen.generate_chain_assembly(3)
        assert not np.allclose(bodies[1].orientation, np.eye(3))
        # Loop
        bodies, _, _ = gen.generate_loop_assembly(4)
        assert not np.allclose(bodies[1].orientation, np.eye(3))
        # Mixed
        bodies, _, _ = gen.generate_mixed_assembly(4)
        assert not np.allclose(bodies[1].orientation, np.eye(3))
# ---------------------------------------------------------------------------
# Geometric diversity: grounded parameter
# ---------------------------------------------------------------------------
class TestGroundedParameter:
    """Grounded parameter controls ground_body in analysis."""

    def test_chain_grounded_default(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(4)
        assert analysis.combinatorial_dof >= 0

    def test_chain_floating(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(
            4,
            grounded=False,
        )
        # Floating: 6 trivial DOF not subtracted by ground
        assert analysis.combinatorial_dof >= 6

    def test_floating_vs_grounded_dof_difference(self) -> None:
        # Fresh generators with the same seed so both chains are identical
        # apart from the grounding choice.
        gen1 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
        gen2 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
        # Floating should have higher DOF due to missing ground constraint
        assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof

    @pytest.mark.parametrize(
        "gen_method",
        [
            "generate_chain_assembly",
            "generate_rigid_assembly",
            "generate_tree_assembly",
            "generate_loop_assembly",
            "generate_star_assembly",
            "generate_mixed_assembly",
        ],
    )
    def test_all_generators_accept_grounded(
        self,
        gen_method: str,
    ) -> None:
        """Every generator accepts grounded=False without raising."""
        # Fix: the original if/else had two byte-identical branches
        # (both called method(n, grounded=False)); collapsed to one call.
        gen = SyntheticAssemblyGenerator(seed=42)
        method = getattr(gen, gen_method)
        # Should not raise
        method(4, grounded=False)
# ---------------------------------------------------------------------------
# Geometric diversity: parallel axis injection
# ---------------------------------------------------------------------------
class TestParallelAxisInjection:
    """parallel_axis_prob causes shared axis direction."""

    @staticmethod
    def _assert_all_parallel(joints) -> None:
        # Near-parallel: |dot| with the first joint's axis stays close to 1.
        reference = joints[0].axis
        for joint in joints[1:]:
            assert abs(np.dot(joint.axis, reference)) > 0.9

    def test_parallel_axes_similar(self) -> None:
        """prob=1.0 forces every chain axis near-parallel to the first."""
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_chain_assembly(6, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)

    def test_zero_prob_no_forced_parallel(self) -> None:
        """prob=0.0 leaves axes random; not all end up parallel."""
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_chain_assembly(6, parallel_axis_prob=0.0)
        reference = joints[0].axis
        overlaps = [abs(np.dot(j.axis, reference)) for j in joints[1:]]
        # With 5 random axes, extremely unlikely all are parallel
        assert min(overlaps) < 0.95

    def test_parallel_on_loop(self) -> None:
        """Parallel axes on a loop assembly."""
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_loop_assembly(5, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)

    def test_parallel_on_star(self) -> None:
        """Parallel axes on a star assembly."""
        gen = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = gen.generate_star_assembly(5, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
class TestComplexityTiers:
    """Complexity tier parameter on batch generation."""

    @staticmethod
    def _check_tier(batch_size: int, tier: str) -> None:
        # Every example's body count must fall inside the tier's half-open range.
        gen = SyntheticAssemblyGenerator(seed=42)
        batch = gen.generate_training_batch(batch_size, complexity_tier=tier)
        lo, hi = COMPLEXITY_RANGES[tier]
        for example in batch:
            assert lo <= example["n_bodies"] < hi

    def test_simple_range(self) -> None:
        self._check_tier(20, "simple")

    def test_medium_range(self) -> None:
        self._check_tier(20, "medium")

    def test_complex_range(self) -> None:
        self._check_tier(3, "complex")

    def test_tier_overrides_range(self) -> None:
        """An explicit tier wins over a conflicting n_bodies_range."""
        gen = SyntheticAssemblyGenerator(seed=42)
        batch = gen.generate_training_batch(
            10,
            n_bodies_range=(2, 3),
            complexity_tier="medium",
        )
        lo, hi = COMPLEXITY_RANGES["medium"]
        for example in batch:
            assert lo <= example["n_bodies"] < hi
# ---------------------------------------------------------------------------
# Training batch
# ---------------------------------------------------------------------------
class TestTrainingBatch:
    """generate_training_batch produces well-structured examples."""

    # Exact key set every training example must expose.
    EXPECTED_KEYS: ClassVar[set[str]] = {
        "example_id",
        "generator_type",
        "grounded",
        "n_bodies",
        "n_joints",
        "body_positions",
        "body_orientations",
        "joints",
        "joint_labels",
        "labels",
        "assembly_classification",
        "is_rigid",
        "is_minimally_rigid",
        "internal_dof",
        "geometric_degeneracies",
    }
    # All generator types the batch sampler may pick from.
    VALID_GEN_TYPES: ClassVar[set[str]] = {
        "chain",
        "rigid",
        "overconstrained",
        "tree",
        "loop",
        "star",
        "mixed",
    }

    @staticmethod
    def _batch(size: int):
        # Fresh deterministic generator per test keeps cases independent.
        return SyntheticAssemblyGenerator(seed=42).generate_training_batch(size)

    def test_batch_size(self) -> None:
        assert len(self._batch(20)) == 20

    def test_example_keys(self) -> None:
        for example in self._batch(10):
            assert set(example.keys()) == self.EXPECTED_KEYS

    def test_example_ids_sequential(self) -> None:
        ids = [example["example_id"] for example in self._batch(15)]
        assert ids == list(range(15))

    def test_generator_type_valid(self) -> None:
        for example in self._batch(50):
            assert example["generator_type"] in self.VALID_GEN_TYPES

    def test_generator_type_diversity(self) -> None:
        """100-sample batch should use at least 5 of 7 generator types."""
        seen = {example["generator_type"] for example in self._batch(100)}
        assert len(seen) >= 5

    def test_default_body_range(self) -> None:
        for example in self._batch(30):
            # default (3, 8), but loop/star may clamp
            assert 2 <= example["n_bodies"] <= 7

    def test_joint_label_consistency(self) -> None:
        """independent + redundant == total for every joint."""
        for example in self._batch(30):
            for label in example["joint_labels"].values():
                combined = label["independent_constraints"] + label["redundant_constraints"]
                assert combined == label["total_constraints"]

    def test_body_orientations_present(self) -> None:
        """Each example includes body_orientations as 3x3 lists."""
        for example in self._batch(10):
            orientations = example["body_orientations"]
            assert len(orientations) == example["n_bodies"]
            for matrix in orientations:
                assert len(matrix) == 3
                assert len(matrix[0]) == 3

    def test_labels_structure(self) -> None:
        """Each example has labels dict with expected sub-keys."""
        for example in self._batch(10):
            labels = example["labels"]
            for key in ("per_constraint", "per_joint", "per_body", "assembly"):
                assert key in labels

    def test_grounded_field_present(self) -> None:
        for example in self._batch(10):
            assert isinstance(example["grounded"], bool)
# ---------------------------------------------------------------------------
# Batch grounded ratio
# ---------------------------------------------------------------------------
class TestBatchGroundedRatio:
    """grounded_ratio controls the mix in batch generation."""

    def test_all_grounded(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            20, grounded_ratio=1.0
        )
        assert all(example["grounded"] for example in batch)

    def test_none_grounded(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            20, grounded_ratio=0.0
        )
        assert not any(example["grounded"] for example in batch)

    def test_mixed_ratio(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            100, grounded_ratio=0.5
        )
        # With 100 samples and p=0.5, should be roughly 50 +/- 20
        n_grounded = sum(1 for example in batch if example["grounded"])
        assert 20 < n_grounded < 80

    def test_batch_axis_strategy_cardinal(self) -> None:
        """axis_strategy='cardinal' is accepted and fills the batch."""
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            10, axis_strategy="cardinal"
        )
        assert len(batch) == 10

    def test_batch_parallel_axis_prob(self) -> None:
        """parallel_axis_prob is accepted and fills the batch."""
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            10, parallel_axis_prob=0.5
        )
        assert len(batch) == 10
# ---------------------------------------------------------------------------
# Seed reproducibility
# ---------------------------------------------------------------------------
class TestSeedReproducibility:
    """Different seeds produce different results."""

    def test_different_seeds_differ(self) -> None:
        """Seeds 1 and 2 must diverge in classification or rigidity."""
        batches = [
            SyntheticAssemblyGenerator(seed=s).generate_training_batch(
                batch_size=5,
                n_bodies_range=(3, 6),
            )
            for s in (1, 2)
        ]
        classifications = [
            [example["assembly_classification"] for example in batch]
            for batch in batches
        ]
        rigidity = [[example["is_rigid"] for example in batch] for batch in batches]
        assert classifications[0] != classifications[1] or rigidity[0] != rigidity[1]

    def test_same_seed_identical(self) -> None:
        """Identical seeds replay identical tree assemblies."""
        bodies_a, joints_a, _ = SyntheticAssemblyGenerator(seed=123).generate_tree_assembly(5)
        bodies_b, joints_b, _ = SyntheticAssemblyGenerator(seed=123).generate_tree_assembly(5)
        for left, right in zip(bodies_a, bodies_b, strict=True):
            np.testing.assert_array_almost_equal(left.position, right.position)
        assert len(joints_a) == len(joints_b)

View File

@@ -0,0 +1,267 @@
"""Tests for solver.datagen.jacobian -- Jacobian rank verification."""
from __future__ import annotations
import numpy as np
import pytest
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.types import Joint, JointType, RigidBody
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _two_bodies() -> list[RigidBody]:
    """Two bodies 2 units apart along the X axis."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
def _three_bodies() -> list[RigidBody]:
    """Three collinear bodies spaced 2 units apart along the X axis.

    NOTE(review): not referenced by any test visible in this module —
    confirm whether it is kept for future cases or should be removed.
    """
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [4.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestJacobianShape:
    """Verify Jacobian matrix dimensions for each joint type."""

    @pytest.mark.parametrize(
        "joint_type,expected_rows",
        [
            (JointType.FIXED, 6),
            (JointType.REVOLUTE, 5),
            (JointType.CYLINDRICAL, 4),
            (JointType.SLIDER, 5),
            (JointType.BALL, 3),
            (JointType.PLANAR, 3),
            (JointType.SCREW, 5),
            (JointType.UNIVERSAL, 4),
            (JointType.PARALLEL, 3),
            (JointType.PERPENDICULAR, 1),
            (JointType.DISTANCE, 1),
        ],
    )
    def test_row_count(self, joint_type: JointType, expected_rows: int) -> None:
        """Each joint type contributes its expected number of constraint rows."""
        verifier = JacobianVerifier(_two_bodies())
        joint = Joint(
            joint_id=0,
            body_a=0,
            body_b=1,
            joint_type=joint_type,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([1.0, 0.0, 0.0]),
            axis=np.array([0.0, 0.0, 1.0]),
        )
        assert verifier.add_joint_constraints(joint) == expected_rows
        # 2 bodies * 6 cols
        assert verifier.get_jacobian().shape == (expected_rows, 12)
class TestNumericalRank:
    """Numerical rank checks for known configurations."""

    def test_fixed_joint_rank_six(self) -> None:
        """Fixed joint between 2 bodies: rank = 6."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=JointType.FIXED,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
            )
        )
        assert verifier.numerical_rank() == 6

    def test_revolute_joint_rank_five(self) -> None:
        """Revolute joint between 2 bodies: rank = 5."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        assert verifier.numerical_rank() == 5

    def test_ball_joint_rank_three(self) -> None:
        """Ball joint between 2 bodies: rank = 3."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=JointType.BALL,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
            )
        )
        assert verifier.numerical_rank() == 3

    def test_empty_jacobian_rank_zero(self) -> None:
        """No constraints at all: rank is 0."""
        assert JacobianVerifier(_two_bodies()).numerical_rank() == 0
class TestParallelAxesDegeneracy:
    """Parallel revolute axes create geometric dependencies."""

    def _four_body_loop(self) -> list[RigidBody]:
        # Square loop of bodies in the XY plane.
        corners = (
            [0.0, 0.0, 0.0],
            [2.0, 0.0, 0.0],
            [2.0, 2.0, 0.0],
            [0.0, 2.0, 0.0],
        )
        return [RigidBody(i, position=np.array(c)) for i, c in enumerate(corners)]

    def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]:
        # (body_a, body_b, shared anchor) around the square.
        pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])]
        return [
            Joint(
                joint_id=i,
                body_a=a,
                body_b=b,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array(anc, dtype=float),
                anchor_b=np.array(anc, dtype=float),
                axis=axes[i],
            )
            for i, (a, b, anc) in enumerate(pairs)
        ]

    def _loop_rank(self, axes: list[np.ndarray]) -> int:
        # Rank of the full loop Jacobian for the given joint axes.
        verifier = JacobianVerifier(self._four_body_loop())
        for joint in self._loop_joints(axes):
            verifier.add_joint_constraints(joint)
        return verifier.numerical_rank()

    def test_parallel_has_lower_rank(self) -> None:
        """4-body closed loop: all-parallel revolute axes produce lower
        Jacobian rank than mixed axes due to geometric dependency."""
        z_axis = np.array([0.0, 0.0, 1.0])
        mixed = [
            np.array([0.0, 0.0, 1.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 0.0, 0.0]),
        ]
        assert self._loop_rank([z_axis] * 4) < self._loop_rank(mixed)
class TestFindDependencies:
    """Dependency detection."""

    @staticmethod
    def _fixed_joint(jid: int) -> Joint:
        # Fixed joint between bodies 0 and 1 with a shared anchor.
        return Joint(
            joint_id=jid,
            body_a=0,
            body_b=1,
            joint_type=JointType.FIXED,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([1.0, 0.0, 0.0]),
        )

    def test_fixed_joint_no_dependencies(self) -> None:
        """A single fixed joint is fully independent."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(self._fixed_joint(0))
        assert verifier.find_dependencies() == []

    def test_duplicate_fixed_has_dependencies(self) -> None:
        """Two fixed joints on same pair: second is fully dependent."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(self._fixed_joint(0))
        verifier.add_joint_constraints(self._fixed_joint(1))
        # Second fixed joint entirely redundant
        assert len(verifier.find_dependencies()) == 6

    def test_empty_no_dependencies(self) -> None:
        """No constraints means no dependencies."""
        assert JacobianVerifier(_two_bodies()).find_dependencies() == []
class TestRowLabels:
    """Row label metadata."""

    def test_labels_match_rows(self) -> None:
        """One label per constraint row, each tagged with the joint id."""
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=7,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        labels = verifier.row_labels
        assert len(labels) == 5
        assert all(label["joint_id"] == 7 for label in labels)
class TestPerpendicularPair:
    """Internal _perpendicular_pair utility."""

    @pytest.mark.parametrize(
        "axis",
        [
            np.array([1.0, 0.0, 0.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 1.0, 1.0]) / np.sqrt(3),
        ],
    )
    def test_orthonormal(self, axis: np.ndarray) -> None:
        """The returned pair plus the axis forms an orthonormal triad."""
        verifier = JacobianVerifier(_two_bodies())
        u, w = verifier._perpendicular_pair(axis)
        # Both vectors are unit length.
        np.testing.assert_allclose(np.linalg.norm(u), 1.0, atol=1e-12)
        np.testing.assert_allclose(np.linalg.norm(w), 1.0, atol=1e-12)
        # Every pairing is perpendicular.
        for left, right in ((axis, u), (axis, w), (u, w)):
            np.testing.assert_allclose(np.dot(left, right), 0.0, atol=1e-12)

View File

@@ -0,0 +1,346 @@
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
from __future__ import annotations
import json
import numpy as np
from solver.datagen.labeling import (
label_assembly,
)
from solver.datagen.types import Joint, JointType, RigidBody
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
    """Build one RigidBody per position tuple, ids assigned in order."""
    bodies: list[RigidBody] = []
    for idx, pos in enumerate(positions):
        bodies.append(RigidBody(body_id=idx, position=np.array(pos)))
    return bodies
def _make_joint(
    jid: int,
    a: int,
    b: int,
    jtype: JointType,
    axis: tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Joint:
    """Joint of type *jtype* between bodies *a* and *b*, anchored at the origin."""
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=jtype,
        anchor_a=np.zeros(3),
        anchor_b=np.zeros(3),
        axis=np.array(axis),
    )
# ---------------------------------------------------------------------------
# Per-constraint labels
# ---------------------------------------------------------------------------
class TestConstraintLabels:
    """Per-constraint labels combine pebble game and Jacobian results."""

    @staticmethod
    def _assert_all_independent(labels) -> None:
        # Both labelers must agree that every constraint is independent.
        for constraint in labels.per_constraint:
            assert constraint.pebble_independent is True
            assert constraint.jacobian_independent is True

    def test_fixed_joint_all_independent(self) -> None:
        """A lone fixed joint yields 6 fully independent constraints."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_constraint) == 6
        self._assert_all_independent(labels)

    def test_revolute_joint_all_independent(self) -> None:
        """A lone revolute joint yields 5 fully independent constraints."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_constraint) == 5
        self._assert_all_independent(labels)

    def test_chain_constraint_count(self) -> None:
        """Two revolute joints contribute 5 constraint rows each."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_constraint) == 10  # 5 + 5

    def test_constraint_joint_ids(self) -> None:
        """Constraints carry the id of their originating joint."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        count_j0 = sum(1 for c in labels.per_constraint if c.joint_id == 0)
        count_j1 = sum(1 for c in labels.per_constraint if c.joint_id == 1)
        assert count_j0 == 5  # revolute
        assert count_j1 == 3  # ball

    def test_overconstrained_has_pebble_redundant(self) -> None:
        """Triangle with revolute joints: some constraints redundant."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        n_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent)
        assert n_redundant > 0
# ---------------------------------------------------------------------------
# Per-joint labels
# ---------------------------------------------------------------------------
class TestJointLabels:
    """Per-joint aggregated labels."""

    def test_fixed_joint_counts(self) -> None:
        """One fixed joint: 6 independent, 0 redundant, total 6."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_joint) == 1
        joint_label = labels.per_joint[0]
        assert joint_label.joint_id == 0
        assert joint_label.independent_count == 6
        assert joint_label.redundant_count == 0
        assert joint_label.total == 6

    def test_overconstrained_has_redundant_joints(self) -> None:
        """Triangle loop must report redundancy on at least one joint."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert sum(jl.redundant_count for jl in labels.per_joint) > 0

    def test_joint_total_equals_dof(self) -> None:
        """A joint's total constraint count matches its constrained DOF."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.BALL)]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert labels.per_joint[0].total == 3  # ball has 3 DOF
# ---------------------------------------------------------------------------
# Per-body DOF labels
# ---------------------------------------------------------------------------
class TestBodyDofLabels:
    """Per-body DOF signatures from nullspace projection."""

    @staticmethod
    def _label_revolute_pair(ground_body):
        # Two bodies linked by a single revolute joint.
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        return label_assembly(bodies, joints, ground_body=ground_body)

    def test_fixed_joint_grounded_both_zero(self) -> None:
        """Two bodies + fixed joint + grounded: both fully constrained."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        labels = label_assembly(bodies, joints, ground_body=0)
        for body_label in labels.per_body:
            assert body_label.translational_dof == 0
            assert body_label.rotational_dof == 0

    def test_revolute_has_rotational_dof(self) -> None:
        """Two bodies + revolute + grounded: body 1 has rotational DOF."""
        labels = self._label_revolute_pair(0)
        body_one = next(b for b in labels.per_body if b.body_id == 1)
        # Revolute allows 1 rotation DOF
        assert body_one.rotational_dof >= 1

    def test_dof_bounds(self) -> None:
        """All DOF values should be in [0, 3]."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        for body_label in labels.per_body:
            assert 0 <= body_label.translational_dof <= 3
            assert 0 <= body_label.rotational_dof <= 3

    def test_floating_more_dof_than_grounded(self) -> None:
        """Removing the ground strictly increases summed body DOF."""
        def total_dof(labels) -> int:
            return sum(b.translational_dof + b.rotational_dof for b in labels.per_body)

        grounded_total = total_dof(self._label_revolute_pair(0))
        floating_total = total_dof(self._label_revolute_pair(None))
        assert floating_total > grounded_total

    def test_grounded_body_zero_dof(self) -> None:
        """The grounded body should have 0 DOF."""
        labels = self._label_revolute_pair(0)
        body_zero = next(b for b in labels.per_body if b.body_id == 0)
        assert body_zero.translational_dof == 0
        assert body_zero.rotational_dof == 0

    def test_body_count_matches(self) -> None:
        """One per-body label per body in the assembly."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_body) == 3
# ---------------------------------------------------------------------------
# Assembly label
# ---------------------------------------------------------------------------
class TestAssemblyLabel:
    """Assembly-wide summary labels."""

    def test_underconstrained_chain(self) -> None:
        """An open 3-body revolute chain is underconstrained, not rigid."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert labels.assembly.classification == "underconstrained"
        assert labels.assembly.is_rigid is False

    def test_well_constrained(self) -> None:
        """A single fixed joint makes a minimally rigid pair."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.classification == "well-constrained"
        assert summary.is_rigid is True
        assert summary.is_minimally_rigid is True

    def test_overconstrained(self) -> None:
        """A revolute triangle reports redundant constraints."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert labels.assembly.redundant_count > 0

    def test_has_degeneracy_with_parallel_axes(self) -> None:
        """Parallel revolute axes in a loop create geometric degeneracy."""
        z_axis = (0.0, 0.0, 1.0)
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
        loop_pairs = [(0, 1), (1, 2), (2, 3), (3, 0)]
        joints = [
            _make_joint(jid, a, b, JointType.REVOLUTE, axis=z_axis)
            for jid, (a, b) in enumerate(loop_pairs)
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert labels.assembly.has_degeneracy is True
# ---------------------------------------------------------------------------
# Serialization
# ---------------------------------------------------------------------------
class TestToDict:
    """to_dict produces JSON-serializable output."""

    @staticmethod
    def _labels_dict() -> dict:
        # Canonical tiny assembly: two bodies joined by one revolute joint.
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        return label_assembly(bodies, joints, ground_body=0).to_dict()

    def test_top_level_keys(self) -> None:
        assert set(self._labels_dict().keys()) == {
            "per_constraint",
            "per_joint",
            "per_body",
            "assembly",
        }

    def test_per_constraint_keys(self) -> None:
        expected = {
            "joint_id",
            "constraint_idx",
            "pebble_independent",
            "jacobian_independent",
        }
        for item in self._labels_dict()["per_constraint"]:
            assert set(item.keys()) == expected

    def test_assembly_keys(self) -> None:
        assert set(self._labels_dict()["assembly"].keys()) == {
            "classification",
            "total_dof",
            "redundant_count",
            "is_rigid",
            "is_minimally_rigid",
            "has_degeneracy",
        }

    def test_json_serializable(self) -> None:
        # Should not raise
        serialized = json.dumps(self._labels_dict())
        assert isinstance(serialized, str)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestLabelAssemblyEdgeCases:
    """Edge cases for label_assembly."""

    def test_no_joints(self) -> None:
        """Jointless grounded pair: empty labels, fully free non-ground body."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [], ground_body=0)
        assert len(labels.per_constraint) == 0
        assert len(labels.per_joint) == 0
        assert labels.assembly.classification == "underconstrained"
        # Non-ground body should be fully free
        free_body = next(b for b in labels.per_body if b.body_id == 1)
        assert free_body.translational_dof == 3
        assert free_body.rotational_dof == 3

    def test_no_joints_floating(self) -> None:
        """A single floating body keeps all 6 DOF."""
        labels = label_assembly(_make_bodies((0, 0, 0)), [], ground_body=None)
        assert len(labels.per_body) == 1
        only_body = labels.per_body[0]
        assert only_body.translational_dof == 3
        assert only_body.rotational_dof == 3

    def test_analysis_embedded(self) -> None:
        """AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        labels = label_assembly(bodies, joints, ground_body=0)
        for attr in ("combinatorial_classification", "jacobian_rank", "is_rigid"):
            assert hasattr(labels.analysis, attr)

View File

@@ -0,0 +1,206 @@
"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game."""
from __future__ import annotations
import itertools
import numpy as np
import pytest
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import Joint, JointType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint:
    """Shorthand for a revolute joint between bodies *a* and *b*."""
    # Default axis is +Z when none is supplied.
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.REVOLUTE,
        axis=np.array([0.0, 0.0, 1.0]) if axis is None else axis,
    )
def _fixed(jid: int, a: int, b: int) -> Joint:
    """Shorthand for a fixed joint between bodies *a* and *b*."""
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.FIXED,
    )
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAddBody:
    """Body registration basics."""

    def test_single_body_six_pebbles(self) -> None:
        """A fresh body starts with all 6 pebbles."""
        game = PebbleGame3D()
        game.add_body(0)
        assert game.state.free_pebbles[0] == 6

    def test_duplicate_body_no_op(self) -> None:
        """Re-adding the same body must not change its pebble count."""
        game = PebbleGame3D()
        for _ in range(2):
            game.add_body(0)
        assert game.state.free_pebbles[0] == 6

    def test_multiple_bodies(self) -> None:
        """Total DOF scales as 6 per registered body."""
        game = PebbleGame3D()
        for body_id in range(5):
            game.add_body(body_id)
        assert game.get_dof() == 30  # 5 * 6
class TestAddJoint:
    """Joint insertion and DOF accounting."""

    def test_revolute_removes_five_dof(self) -> None:
        """Revolute contributes 5 independent scalar constraints."""
        game = PebbleGame3D()
        results = game.add_joint(_revolute(0, 0, 1))
        assert len(results) == 5  # 5 scalar constraints
        assert all(entry["independent"] for entry in results)
        # 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles
        assert game.get_dof() == 7

    def test_fixed_removes_six_dof(self) -> None:
        """Fixed contributes 6 independent scalar constraints."""
        game = PebbleGame3D()
        results = game.add_joint(_fixed(0, 0, 1))
        assert len(results) == 6
        assert all(entry["independent"] for entry in results)
        assert game.get_dof() == 6

    def test_ball_removes_three_dof(self) -> None:
        """Ball contributes 3 independent scalar constraints."""
        game = PebbleGame3D()
        ball = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL)
        results = game.add_joint(ball)
        assert len(results) == 3
        assert all(entry["independent"] for entry in results)
        assert game.get_dof() == 9
class TestTwoBodiesRevolute:
    """Two bodies connected by a revolute -- demo scenario 1."""

    @staticmethod
    def _game() -> PebbleGame3D:
        # Fresh game with one revolute joint between bodies 0 and 1.
        game = PebbleGame3D()
        game.add_joint(_revolute(0, 0, 1))
        return game

    def test_internal_dof(self) -> None:
        # Total DOF = 7, internal = 7 - 6 = 1
        assert self._game().get_internal_dof() == 1

    def test_not_rigid(self) -> None:
        assert not self._game().is_rigid()

    def test_classification(self) -> None:
        assert self._game().classify_assembly() == "underconstrained"
class TestTwoBodiesFixed:
    """Two bodies + fixed joint -- demo scenario 2."""

    @staticmethod
    def _game() -> PebbleGame3D:
        # Fresh game with one fixed joint between bodies 0 and 1.
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        return game

    def test_zero_internal_dof(self) -> None:
        assert self._game().get_internal_dof() == 0

    def test_rigid(self) -> None:
        assert self._game().is_rigid()

    def test_well_constrained(self) -> None:
        assert self._game().classify_assembly() == "well-constrained"
class TestTriangleRevolute:
    """Triangle of 3 bodies with revolute joints -- demo scenario 3."""

    @pytest.fixture()
    def pg(self) -> PebbleGame3D:
        """Closed 3-body revolute loop."""
        game = PebbleGame3D()
        for jid, (a, b) in enumerate([(0, 1), (1, 2), (2, 0)]):
            game.add_joint(_revolute(jid, a, b))
        return game

    def test_has_redundant_edges(self, pg: PebbleGame3D) -> None:
        assert pg.get_redundant_count() > 0

    def test_classification_overconstrained(self, pg: PebbleGame3D) -> None:
        # 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed)
        assert pg.classify_assembly() in ("overconstrained", "mixed")

    def test_rigid(self, pg: PebbleGame3D) -> None:
        assert pg.is_rigid()
class TestChainNotRigid:
    """A serial chain of 4 bodies with revolute joints is never rigid."""

    @staticmethod
    def _chain_game() -> PebbleGame3D:
        # Serial chain 0-1-2-3 with one revolute per link.
        game = PebbleGame3D()
        for link in range(3):
            game.add_joint(_revolute(link, link, link + 1))
        return game

    def test_chain_underconstrained(self) -> None:
        game = self._chain_game()
        assert not game.is_rigid()
        assert game.classify_assembly() == "underconstrained"

    def test_chain_internal_dof(self) -> None:
        # 4 bodies * 6 = 24, minus 15 independent = 9 free, internal = 3
        assert self._chain_game().get_internal_dof() == 3
class TestEdgeResults:
    """Result dicts returned by add_joint."""

    def test_result_keys(self) -> None:
        """Each result dict exposes exactly the documented keys."""
        game = PebbleGame3D()
        results = game.add_joint(_revolute(0, 0, 1))
        expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"}
        assert all(set(entry.keys()) == expected_keys for entry in results)

    def test_edge_ids_sequential(self) -> None:
        """Edge ids keep counting across successive joints."""
        game = PebbleGame3D()
        combined = game.add_joint(_revolute(0, 0, 1)) + game.add_joint(_revolute(1, 1, 2))
        assert [entry["edge_id"] for entry in combined] == list(range(10))

    def test_dof_remaining_monotonic(self) -> None:
        """dof_remaining never increases within one joint's results."""
        game = PebbleGame3D()
        dofs = [entry["dof_remaining"] for entry in game.add_joint(_revolute(0, 0, 1))]
        # Should be non-increasing (each independent edge removes a pebble)
        for earlier, later in itertools.pairwise(dofs):
            assert earlier >= later
class TestGroundedClassification:
    """classify_assembly with grounded=True."""

    def test_grounded_baseline_zero(self) -> None:
        """With grounded=True the baseline is 0 (not 6)."""
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        # Ungrounded: well-constrained (6 pebbles = baseline 6)
        assert game.classify_assembly(grounded=False) == "well-constrained"
        # Grounded: the 6 remaining pebbles on body 1 exceed baseline 0,
        # so the raw pebble game (without a virtual ground body) sees this
        # as underconstrained. The analysis function handles this properly
        # by adding a virtual ground body.
        assert game.classify_assembly(grounded=True) == "underconstrained"

163
tests/datagen/test_types.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for solver.datagen.types -- shared data types."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
PebbleState,
RigidBody,
)
class TestJointType:
    """JointType enum construction and DOF values."""

    # Expected removed-DOF count per enum member name.
    EXPECTED_DOF: ClassVar[dict[str, int]] = {
        "FIXED": 6,
        "REVOLUTE": 5,
        "CYLINDRICAL": 4,
        "SLIDER": 5,
        "BALL": 3,
        "PLANAR": 3,
        "SCREW": 5,
        "UNIVERSAL": 4,
        "PARALLEL": 3,
        "PERPENDICULAR": 1,
        "DISTANCE": 1,
    }

    def test_member_count(self) -> None:
        # One enum member for each entry in EXPECTED_DOF.
        assert len(JointType) == 11

    @pytest.mark.parametrize(("name", "dof"), list(EXPECTED_DOF.items()))
    def test_dof_values(self, name: str, dof: int) -> None:
        assert JointType[name].dof == dof

    def test_access_by_name(self) -> None:
        member = JointType["REVOLUTE"]
        assert member is JointType.REVOLUTE

    def test_value_is_tuple(self) -> None:
        revolute = JointType.REVOLUTE
        assert revolute.value == (1, 5)
        assert revolute.dof == 5
class TestRigidBody:
    """RigidBody dataclass defaults and construction."""

    def test_defaults(self) -> None:
        # Defaults: origin position, identity orientation, no anchors.
        rb = RigidBody(body_id=0)
        np.testing.assert_array_equal(rb.position, np.zeros(3))
        np.testing.assert_array_equal(rb.orientation, np.eye(3))
        assert rb.local_anchors == {}

    def test_custom_position(self) -> None:
        where = np.array([1.0, 2.0, 3.0])
        rb = RigidBody(body_id=7, position=where)
        np.testing.assert_array_equal(rb.position, where)
        assert rb.body_id == 7

    def test_local_anchors_mutable(self) -> None:
        rb = RigidBody(body_id=0)
        rb.local_anchors["top"] = np.array([0.0, 0.0, 1.0])
        assert "top" in rb.local_anchors

    def test_default_factory_isolation(self) -> None:
        """Each instance gets its own default containers."""
        first = RigidBody(body_id=0)
        second = RigidBody(body_id=1)
        first.local_anchors["x"] = np.zeros(3)
        # Mutating one instance's dict must not leak into the other.
        assert "x" not in second.local_anchors
class TestJoint:
    """Joint dataclass defaults and construction."""

    def test_defaults(self) -> None:
        joint = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE)
        # Anchors default to the origin, axis to +Z, pitch to zero.
        np.testing.assert_array_equal(joint.anchor_a, np.zeros(3))
        np.testing.assert_array_equal(joint.anchor_b, np.zeros(3))
        np.testing.assert_array_equal(joint.axis, np.array([0.0, 0.0, 1.0]))
        assert joint.pitch == 0.0

    def test_full_construction(self) -> None:
        joint = Joint(
            joint_id=5,
            body_a=2,
            body_b=3,
            joint_type=JointType.SCREW,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([2.0, 0.0, 0.0]),
            axis=np.array([1.0, 0.0, 0.0]),
            pitch=0.5,
        )
        assert joint.joint_id == 5
        assert joint.joint_type is JointType.SCREW
        assert joint.pitch == 0.5
class TestPebbleState:
    """PebbleState dataclass defaults."""

    def test_defaults(self) -> None:
        state = PebbleState()
        # Every container field starts out empty.
        expected_empties = [
            ("free_pebbles", {}),
            ("directed_edges", {}),
            ("independent_edges", set()),
            ("redundant_edges", set()),
            ("incoming", {}),
            ("outgoing", {}),
        ]
        for attr, empty in expected_empties:
            assert getattr(state, attr) == empty

    def test_default_factory_isolation(self) -> None:
        # Mutating one instance's dict must not leak into another's.
        first = PebbleState()
        second = PebbleState()
        first.free_pebbles[0] = 6
        assert 0 not in second.free_pebbles
class TestConstraintAnalysis:
    """ConstraintAnalysis dataclass construction."""

    def test_construction(self) -> None:
        # A fully rigid, minimally rigid assembly.
        fields = {
            "combinatorial_dof": 6,
            "combinatorial_internal_dof": 0,
            "combinatorial_redundant": 0,
            "combinatorial_classification": "well-constrained",
            "per_edge_results": [],
            "jacobian_rank": 6,
            "jacobian_nullity": 0,
            "jacobian_internal_dof": 0,
            "numerically_dependent": [],
            "geometric_degeneracies": 0,
            "is_rigid": True,
            "is_minimally_rigid": True,
        }
        analysis = ConstraintAnalysis(**fields)
        assert analysis.is_rigid is True
        assert analysis.is_minimally_rigid is True
        assert analysis.combinatorial_classification == "well-constrained"

    def test_per_edge_results_typing(self) -> None:
        """per_edge_results accepts list[dict[str, Any]]."""
        analysis = ConstraintAnalysis(
            combinatorial_dof=7,
            combinatorial_internal_dof=1,
            combinatorial_redundant=0,
            combinatorial_classification="underconstrained",
            per_edge_results=[{"edge_id": 0, "independent": True}],
            jacobian_rank=5,
            jacobian_nullity=1,
            jacobian_internal_dof=1,
            numerically_dependent=[],
            geometric_degeneracies=0,
            is_rigid=False,
            is_minimally_rigid=False,
        )
        assert len(analysis.per_edge_results) == 1
        assert analysis.per_edge_results[0]["edge_id"] == 0