Merge remote-tracking branch 'public/main'
# Conflicts: # .gitignore # README.md
This commit is contained in:
65
.gitea/workflows/ci.yaml
Normal file
65
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,65 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main]
|
||||
pull_request:
|
||||
branches: [main]
|
||||
|
||||
jobs:
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install ruff mypy
|
||||
pip install -e ".[dev]" || pip install ruff mypy numpy
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check solver/ freecad/ tests/ scripts/
|
||||
|
||||
- name: Ruff format check
|
||||
run: ruff format --check solver/ freecad/ tests/ scripts/
|
||||
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install mypy numpy
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install torch-geometric
|
||||
pip install -e ".[dev]"
|
||||
|
||||
- name: Mypy
|
||||
run: mypy solver/ freecad/
|
||||
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: "3.11"
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||
pip install torch-geometric
|
||||
pip install -e ".[train,dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: pytest tests/ freecad/tests/ -v --tb=short
|
||||
77
.gitignore
vendored
77
.gitignore
vendored
@@ -1,44 +1,83 @@
|
||||
# Prerequisites
|
||||
# C++ compiled objects
|
||||
*.d
|
||||
|
||||
# Compiled Object files
|
||||
*.slo
|
||||
*.lo
|
||||
*.o
|
||||
*.obj
|
||||
|
||||
# Precompiled Headers
|
||||
*.gch
|
||||
*.pch
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
# C++ libraries
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
|
||||
# Fortran module files
|
||||
*.mod
|
||||
*.smod
|
||||
|
||||
# Compiled Static libraries
|
||||
*.lai
|
||||
*.la
|
||||
*.a
|
||||
*.lib
|
||||
|
||||
# Executables
|
||||
# C++ executables
|
||||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
|
||||
.vs
|
||||
# C++ build
|
||||
build/
|
||||
cmake-build-debug/
|
||||
.vs/
|
||||
x64/
|
||||
temp/
|
||||
|
||||
# OndselSolver test artifacts
|
||||
*.bak
|
||||
assembly.asmt
|
||||
|
||||
build
|
||||
cmake-build-debug
|
||||
.idea
|
||||
temp/
|
||||
/testapp/draggingBackhoe.log
|
||||
/testapp/runPreDragBackhoe.asmt
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
dist/
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# mypy / ruff / pytest
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.pytest_cache/
|
||||
|
||||
# Data (large files tracked separately)
|
||||
data/synthetic/*.pt
|
||||
data/fusion360/*.json
|
||||
data/fusion360/*.step
|
||||
data/processed/*.pt
|
||||
!data/**/.gitkeep
|
||||
|
||||
# Model checkpoints
|
||||
*.ckpt
|
||||
*.pth
|
||||
*.onnx
|
||||
*.torchscript
|
||||
|
||||
# Experiment tracking
|
||||
wandb/
|
||||
runs/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Environment
|
||||
.env
|
||||
|
||||
23
.pre-commit-config.yaml
Normal file
23
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.3.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.8.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- torch>=2.2
|
||||
- numpy>=1.26
|
||||
args: [--ignore-missing-imports]
|
||||
|
||||
- repo: https://github.com/compilerla/conventional-pre-commit
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: conventional-pre-commit
|
||||
stages: [commit-msg]
|
||||
args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert]
|
||||
61
Dockerfile
Normal file
61
Dockerfile
Normal file
@@ -0,0 +1,61 @@
|
||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# System deps
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.11 python3.11-venv python3.11-dev python3-pip \
|
||||
git wget curl \
|
||||
# FreeCAD headless deps
|
||||
freecad \
|
||||
libgl1-mesa-glx libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
|
||||
|
||||
# Create venv
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install PyTorch with CUDA
|
||||
RUN pip install --no-cache-dir \
|
||||
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# Install PyG
|
||||
RUN pip install --no-cache-dir \
|
||||
torch-geometric \
|
||||
pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
|
||||
-f https://data.pyg.org/whl/torch-2.4.0+cu124.html
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install project
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]" || true
|
||||
|
||||
COPY . .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
FROM base AS cpu
|
||||
|
||||
# CPU-only variant (for CI and non-GPU environments)
|
||||
FROM python:3.11-slim AS cpu-only
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git freecad libgl1-mesa-glx libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN pip install --no-cache-dir torch-geometric
|
||||
|
||||
COPY . .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||
|
||||
CMD ["pytest", "tests/", "-v"]
|
||||
48
Makefile
Normal file
48
Makefile
Normal file
@@ -0,0 +1,48 @@
|
||||
.PHONY: train test lint data-gen export format type-check install dev clean help
|
||||
|
||||
PYTHON ?= python
|
||||
PYTEST ?= pytest
|
||||
RUFF ?= ruff
|
||||
MYPY ?= mypy
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
|
||||
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
install: ## Install core dependencies
|
||||
pip install -e .
|
||||
|
||||
dev: ## Install all dependencies including dev tools
|
||||
pip install -e ".[train,dev]"
|
||||
pre-commit install
|
||||
pre-commit install --hook-type commit-msg
|
||||
|
||||
train: ## Run training (pass CONFIG=path/to/config.yaml)
|
||||
$(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG))
|
||||
|
||||
test: ## Run test suite
|
||||
$(PYTEST) tests/ freecad/tests/ -v --tb=short
|
||||
|
||||
lint: ## Run ruff linter
|
||||
$(RUFF) check solver/ freecad/ tests/ scripts/
|
||||
|
||||
format: ## Format code with ruff
|
||||
$(RUFF) format solver/ freecad/ tests/ scripts/
|
||||
$(RUFF) check --fix solver/ freecad/ tests/ scripts/
|
||||
|
||||
type-check: ## Run mypy type checker
|
||||
$(MYPY) solver/ freecad/
|
||||
|
||||
data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml)
|
||||
$(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG))
|
||||
|
||||
export: ## Export trained model for deployment
|
||||
$(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL))
|
||||
|
||||
clean: ## Remove build artifacts and caches
|
||||
rm -rf build/ dist/ *.egg-info/
|
||||
rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/
|
||||
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
|
||||
check: lint type-check test ## Run all checks (lint, type-check, test)
|
||||
94
README.md
94
README.md
@@ -1,7 +1,91 @@
|
||||
# MbDCode
|
||||
Assembly Constraints and Multibody Dynamics code
|
||||
# Kindred Solver
|
||||
|
||||
Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited)
|
||||
Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer.
|
||||
|
||||
The MbD theory is at
|
||||
https://github.com/Ondsel-Development/MbDTheory
|
||||
## Components
|
||||
|
||||
### OndselSolver (C++)
|
||||
|
||||
Numerical assembly constraint solver using multibody dynamics. Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver.
|
||||
|
||||
- Source: `OndselSolver/`
|
||||
- Entry point: `OndselSolverMain/`
|
||||
- Tests: `tests/`, `testapp/`
|
||||
- Build: CMake
|
||||
|
||||
**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory)
|
||||
|
||||
#### Building
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
### ML Solver Layer (Python)
|
||||
|
||||
Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns.
|
||||
|
||||
- Core library: `solver/`
|
||||
- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling)
|
||||
- Model architectures: `solver/models/` (GIN, GAT, NNConv)
|
||||
- Training: `solver/training/`
|
||||
- Inference: `solver/inference/`
|
||||
- FreeCAD integration: `freecad/`
|
||||
- Configuration: `configs/` (Hydra)
|
||||
|
||||
#### Setup
|
||||
|
||||
```bash
|
||||
pip install -e ".[train,dev]"
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
make help # show all targets
|
||||
make dev # install all deps + pre-commit hooks
|
||||
make test # run tests
|
||||
make lint # run ruff linter
|
||||
make check # lint + type-check + test
|
||||
make data-gen # generate synthetic data
|
||||
make train # run training
|
||||
make export # export model
|
||||
```
|
||||
|
||||
Docker is also supported:
|
||||
|
||||
```bash
|
||||
docker compose up train # GPU training
|
||||
docker compose up test # run tests
|
||||
docker compose up data-gen # generate synthetic data
|
||||
```
|
||||
|
||||
## Repository structure
|
||||
|
||||
```
|
||||
kindred-solver/
|
||||
├── OndselSolver/ # C++ numerical solver library
|
||||
├── OndselSolverMain/ # C++ solver CLI entry point
|
||||
├── tests/ # C++ unit tests + Python tests
|
||||
├── testapp/ # C++ test application
|
||||
├── solver/ # Python ML solver library
|
||||
│ ├── datagen/ # Synthetic data generation (pebble game)
|
||||
│ ├── datasets/ # PyG dataset adapters
|
||||
│ ├── models/ # GNN architectures
|
||||
│ ├── training/ # Training loops
|
||||
│ ├── evaluation/ # Metrics and visualization
|
||||
│ └── inference/ # Runtime prediction API
|
||||
├── freecad/ # FreeCAD workbench integration
|
||||
├── configs/ # Hydra configs (dataset, model, training, export)
|
||||
├── scripts/ # CLI utilities
|
||||
├── data/ # Datasets (not committed)
|
||||
├── export/ # Model packaging
|
||||
└── docs/ # Documentation
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE))
|
||||
ML Solver Layer: Apache-2.0
|
||||
|
||||
12
configs/dataset/fusion360.yaml
Normal file
12
configs/dataset/fusion360.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# Fusion 360 Gallery dataset config
|
||||
name: fusion360
|
||||
data_dir: data/fusion360
|
||||
output_dir: data/processed
|
||||
|
||||
splits:
|
||||
train: 0.8
|
||||
val: 0.1
|
||||
test: 0.1
|
||||
|
||||
stratify_by: complexity
|
||||
seed: 42
|
||||
26
configs/dataset/synthetic.yaml
Normal file
26
configs/dataset/synthetic.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
# Synthetic dataset generation config
|
||||
name: synthetic
|
||||
num_assemblies: 100000
|
||||
output_dir: data/synthetic
|
||||
shard_size: 1000
|
||||
|
||||
complexity_distribution:
|
||||
simple: 0.4 # 2-5 bodies
|
||||
medium: 0.4 # 6-15 bodies
|
||||
complex: 0.2 # 16-50 bodies
|
||||
|
||||
body_count:
|
||||
min: 2
|
||||
max: 50
|
||||
|
||||
templates:
|
||||
- chain
|
||||
- tree
|
||||
- loop
|
||||
- star
|
||||
- mixed
|
||||
|
||||
grounded_ratio: 0.5
|
||||
seed: 42
|
||||
num_workers: 4
|
||||
checkpoint_every: 5
|
||||
25
configs/export/production.yaml
Normal file
25
configs/export/production.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
# Production model export config
|
||||
model_checkpoint: checkpoints/finetune/best_val_loss.ckpt
|
||||
output_dir: export/
|
||||
|
||||
formats:
|
||||
onnx:
|
||||
enabled: true
|
||||
opset_version: 17
|
||||
dynamic_axes: true
|
||||
torchscript:
|
||||
enabled: true
|
||||
|
||||
model_card:
|
||||
version: "0.1.0"
|
||||
architecture: baseline
|
||||
training_data:
|
||||
- synthetic_100k
|
||||
- fusion360_gallery
|
||||
|
||||
size_budget_mb: 50
|
||||
|
||||
inference:
|
||||
device: cpu
|
||||
batch_size: 1
|
||||
confidence_threshold: 0.8
|
||||
24
configs/model/baseline.yaml
Normal file
24
configs/model/baseline.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
# Baseline GIN model config
|
||||
name: baseline
|
||||
architecture: gin
|
||||
|
||||
encoder:
|
||||
num_layers: 3
|
||||
hidden_dim: 128
|
||||
dropout: 0.1
|
||||
|
||||
node_features_dim: 22
|
||||
edge_features_dim: 22
|
||||
|
||||
heads:
|
||||
edge_classification:
|
||||
enabled: true
|
||||
hidden_dim: 64
|
||||
graph_classification:
|
||||
enabled: true
|
||||
num_classes: 4 # rigid, under, over, mixed
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
dof_regression:
|
||||
enabled: true
|
||||
28
configs/model/gat.yaml
Normal file
28
configs/model/gat.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
# Advanced GAT model config
|
||||
name: gat_solver
|
||||
architecture: gat
|
||||
|
||||
encoder:
|
||||
num_layers: 4
|
||||
hidden_dim: 256
|
||||
num_heads: 8
|
||||
dropout: 0.1
|
||||
residual: true
|
||||
|
||||
node_features_dim: 22
|
||||
edge_features_dim: 22
|
||||
|
||||
heads:
|
||||
edge_classification:
|
||||
enabled: true
|
||||
hidden_dim: 128
|
||||
graph_classification:
|
||||
enabled: true
|
||||
num_classes: 4
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
dof_regression:
|
||||
enabled: true
|
||||
dof_tracking:
|
||||
enabled: true
|
||||
45
configs/training/finetune.yaml
Normal file
45
configs/training/finetune.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
# Fine-tuning on real data config
|
||||
phase: finetune
|
||||
|
||||
dataset: fusion360
|
||||
model: baseline
|
||||
|
||||
pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt
|
||||
|
||||
optimizer:
|
||||
name: adamw
|
||||
lr: 1e-5
|
||||
weight_decay: 1e-4
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
T_max: 50
|
||||
eta_min: 1e-7
|
||||
|
||||
training:
|
||||
epochs: 50
|
||||
batch_size: 32
|
||||
gradient_clip: 1.0
|
||||
early_stopping_patience: 10
|
||||
amp: true
|
||||
freeze_encoder: false # set true for frozen encoder experiment
|
||||
|
||||
loss:
|
||||
edge_weight: 1.0
|
||||
graph_weight: 0.5
|
||||
joint_type_weight: 0.3
|
||||
dof_weight: 0.2
|
||||
redundant_penalty: 2.0
|
||||
|
||||
checkpointing:
|
||||
save_best_val_loss: true
|
||||
save_best_val_accuracy: true
|
||||
save_every_n_epochs: 5
|
||||
checkpoint_dir: checkpoints/finetune
|
||||
|
||||
logging:
|
||||
backend: wandb
|
||||
project: kindred-solver
|
||||
log_every_n_steps: 20
|
||||
|
||||
seed: 42
|
||||
42
configs/training/pretrain.yaml
Normal file
42
configs/training/pretrain.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Synthetic pre-training config
|
||||
phase: pretrain
|
||||
|
||||
dataset: synthetic
|
||||
model: baseline
|
||||
|
||||
optimizer:
|
||||
name: adamw
|
||||
lr: 1e-3
|
||||
weight_decay: 1e-4
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
T_max: 100
|
||||
eta_min: 1e-6
|
||||
|
||||
training:
|
||||
epochs: 100
|
||||
batch_size: 64
|
||||
gradient_clip: 1.0
|
||||
early_stopping_patience: 10
|
||||
amp: true
|
||||
|
||||
loss:
|
||||
edge_weight: 1.0
|
||||
graph_weight: 0.5
|
||||
joint_type_weight: 0.3
|
||||
dof_weight: 0.2
|
||||
redundant_penalty: 2.0 # safety loss multiplier
|
||||
|
||||
checkpointing:
|
||||
save_best_val_loss: true
|
||||
save_best_val_accuracy: true
|
||||
save_every_n_epochs: 10
|
||||
checkpoint_dir: checkpoints/pretrain
|
||||
|
||||
logging:
|
||||
backend: wandb # or tensorboard
|
||||
project: kindred-solver
|
||||
log_every_n_steps: 50
|
||||
|
||||
seed: 42
|
||||
0
data/fusion360/.gitkeep
Normal file
0
data/fusion360/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
39
docker-compose.yml
Normal file
39
docker-compose.yml
Normal file
@@ -0,0 +1,39 @@
|
||||
services:
|
||||
train:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: base
|
||||
volumes:
|
||||
- .:/workspace
|
||||
- ./data:/workspace/data
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
command: make train
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
|
||||
- WANDB_API_KEY=${WANDB_API_KEY:-}
|
||||
|
||||
test:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: cpu-only
|
||||
volumes:
|
||||
- .:/workspace
|
||||
command: make check
|
||||
|
||||
data-gen:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: base
|
||||
volumes:
|
||||
- .:/workspace
|
||||
- ./data:/workspace/data
|
||||
command: make data-gen
|
||||
0
docs/.gitkeep
Normal file
0
docs/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
freecad/__init__.py
Normal file
0
freecad/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
97
pyproject.toml
Normal file
97
pyproject.toml
Normal file
@@ -0,0 +1,97 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "kindred-solver"
|
||||
version = "0.1.0"
|
||||
description = "Assembly constraint prediction via GNN for Kindred Create"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{ name = "Kindred Systems" },
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering",
|
||||
]
|
||||
dependencies = [
|
||||
"torch>=2.2",
|
||||
"torch-geometric>=2.5",
|
||||
"numpy>=1.26",
|
||||
"scipy>=1.12",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
train = [
|
||||
"wandb>=0.16",
|
||||
"tensorboard>=2.16",
|
||||
"hydra-core>=1.3",
|
||||
"omegaconf>=2.3",
|
||||
"matplotlib>=3.8",
|
||||
"networkx>=3.2",
|
||||
]
|
||||
freecad = [
|
||||
"pyside6>=6.6",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-cov>=4.1",
|
||||
"ruff>=0.3",
|
||||
"mypy>=1.8",
|
||||
"pre-commit>=3.6",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://git.kindred-systems.com/kindred/solver"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["solver", "freecad"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"UP", # pyupgrade
|
||||
"B", # flake8-bugbear
|
||||
"SIM", # flake8-simplify
|
||||
"TCH", # flake8-type-checking
|
||||
"RUF", # ruff-specific
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["solver", "freecad"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"torch.*",
|
||||
"torch_geometric.*",
|
||||
"scipy.*",
|
||||
"wandb.*",
|
||||
"hydra.*",
|
||||
"omegaconf.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "freecad/tests"]
|
||||
addopts = "-v --tb=short"
|
||||
115
scripts/generate_synthetic.py
Normal file
115
scripts/generate_synthetic.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate synthetic assembly dataset for kindred-solver training.
|
||||
|
||||
Usage (argparse fallback — always available)::
|
||||
|
||||
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
|
||||
|
||||
Usage (Hydra — when hydra-core is installed)::
|
||||
|
||||
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _try_hydra_main() -> bool:
|
||||
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
||||
try:
|
||||
import hydra # type: ignore[import-untyped]
|
||||
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@hydra.main(
|
||||
config_path="../configs/dataset",
|
||||
config_name="synthetic",
|
||||
version_base=None,
|
||||
)
|
||||
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
|
||||
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
_run() # type: ignore[no-untyped-call]
|
||||
return True
|
||||
|
||||
|
||||
def _argparse_main() -> None:
|
||||
"""Fallback CLI using argparse."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate synthetic assembly dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to YAML config file (optional)",
|
||||
)
|
||||
parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
|
||||
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
|
||||
parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
|
||||
parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
|
||||
parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
|
||||
parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Random seed")
|
||||
parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
|
||||
parser.add_argument(
|
||||
"--checkpoint-every",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Checkpoint interval (shards)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-resume",
|
||||
action="store_true",
|
||||
help="Do not resume from existing checkpoints",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
config_dict: dict[str, object] = {}
|
||||
if args.config:
|
||||
config_dict = parse_simple_yaml(args.config) # type: ignore[assignment]
|
||||
|
||||
# CLI args override config file (only when explicitly provided)
|
||||
_override_map = {
|
||||
"num_assemblies": args.num_assemblies,
|
||||
"output_dir": args.output_dir,
|
||||
"shard_size": args.shard_size,
|
||||
"body_count_min": args.body_count_min,
|
||||
"body_count_max": args.body_count_max,
|
||||
"grounded_ratio": args.grounded_ratio,
|
||||
"seed": args.seed,
|
||||
"num_workers": args.num_workers,
|
||||
"checkpoint_every": args.checkpoint_every,
|
||||
}
|
||||
for key, val in _override_map.items():
|
||||
if val is not None:
|
||||
config_dict[key] = val
|
||||
|
||||
if args.no_resume:
|
||||
config_dict["resume"] = False
|
||||
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point: try Hydra first, fall back to argparse."""
|
||||
if not _try_hydra_main():
|
||||
_argparse_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
solver/__init__.py
Normal file
0
solver/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
from solver.datagen.generator import (
|
||||
COMPLEXITY_RANGES,
|
||||
AxisStrategy,
|
||||
SyntheticAssemblyGenerator,
|
||||
)
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.labeling import AssemblyLabels, label_assembly
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
PebbleState,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AssemblyLabels",
|
||||
"AxisStrategy",
|
||||
"ConstraintAnalysis",
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"JacobianVerifier",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleGame3D",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
"SyntheticAssemblyGenerator",
|
||||
"analyze_assembly",
|
||||
"label_assembly",
|
||||
]
|
||||
140
solver/datagen/analysis.py
Normal file
140
solver/datagen/analysis.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Combined pebble game + Jacobian verification analysis.
|
||||
|
||||
Provides :func:`analyze_assembly`, the main entry point for full rigidity
|
||||
analysis of an assembly using both combinatorial and numerical methods.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
__all__ = ["analyze_assembly"]
|
||||
|
||||
_GROUND_ID = -1
|
||||
|
||||
|
||||
def analyze_assembly(
|
||||
bodies: list[RigidBody],
|
||||
joints: list[Joint],
|
||||
ground_body: int | None = None,
|
||||
) -> ConstraintAnalysis:
|
||||
"""Full rigidity analysis of an assembly using both methods.
|
||||
|
||||
Args:
|
||||
bodies: List of rigid bodies in the assembly.
|
||||
joints: List of joints connecting bodies.
|
||||
ground_body: If set, this body is fixed (adds 6 implicit constraints).
|
||||
|
||||
Returns:
|
||||
ConstraintAnalysis with combinatorial and numerical results.
|
||||
"""
|
||||
# --- Pebble Game ---
|
||||
pg = PebbleGame3D()
|
||||
all_edge_results = []
|
||||
|
||||
# Add a virtual ground body (id=-1) if grounding is requested.
|
||||
# Grounding body X means adding a fixed joint between X and
|
||||
# the virtual ground. This properly lets the pebble game account
|
||||
# for the 6 removed DOF without breaking invariants.
|
||||
if ground_body is not None:
|
||||
pg.add_body(_GROUND_ID)
|
||||
|
||||
for body in bodies:
|
||||
pg.add_body(body.body_id)
|
||||
|
||||
if ground_body is not None:
|
||||
ground_joint = Joint(
|
||||
joint_id=-1,
|
||||
body_a=ground_body,
|
||||
body_b=_GROUND_ID,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=bodies[0].position if bodies else np.zeros(3),
|
||||
anchor_b=bodies[0].position if bodies else np.zeros(3),
|
||||
)
|
||||
pg.add_joint(ground_joint)
|
||||
# Don't include ground joint edges in the output labels
|
||||
# (they're infrastructure, not user constraints)
|
||||
|
||||
for joint in joints:
|
||||
results = pg.add_joint(joint)
|
||||
all_edge_results.extend(results)
|
||||
|
||||
combinatorial_independent = len(pg.state.independent_edges)
|
||||
grounded = ground_body is not None
|
||||
|
||||
# The virtual ground body contributes 6 pebbles to the total.
|
||||
# Subtract those from the reported DOF for user-facing numbers.
|
||||
raw_dof = pg.get_dof()
|
||||
ground_offset = 6 if grounded else 0
|
||||
effective_dof = raw_dof - ground_offset
|
||||
effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))
|
||||
|
||||
# Classify based on effective (adjusted) DOF, not raw pebble game output,
|
||||
# because the virtual ground body skews the raw numbers.
|
||||
redundant = pg.get_redundant_count()
|
||||
if redundant > 0 and effective_internal_dof > 0:
|
||||
combinatorial_classification = "mixed"
|
||||
elif redundant > 0:
|
||||
combinatorial_classification = "overconstrained"
|
||||
elif effective_internal_dof > 0:
|
||||
combinatorial_classification = "underconstrained"
|
||||
else:
|
||||
combinatorial_classification = "well-constrained"
|
||||
|
||||
# --- Jacobian Verification ---
|
||||
verifier = JacobianVerifier(bodies)
|
||||
|
||||
for joint in joints:
|
||||
verifier.add_joint_constraints(joint)
|
||||
|
||||
# If grounded, remove the ground body's columns (fix its DOF)
|
||||
j = verifier.get_jacobian()
|
||||
if ground_body is not None and j.size > 0:
|
||||
idx = verifier.body_index[ground_body]
|
||||
cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
|
||||
j = np.delete(j, cols_to_remove, axis=1)
|
||||
|
||||
if j.size > 0:
|
||||
sv = np.linalg.svd(j, compute_uv=False)
|
||||
jacobian_rank = int(np.sum(sv > 1e-8))
|
||||
else:
|
||||
jacobian_rank = 0
|
||||
|
||||
n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies)
|
||||
jacobian_nullity = n_cols - jacobian_rank
|
||||
|
||||
dependent = verifier.find_dependencies()
|
||||
|
||||
# Adjust for ground
|
||||
trivial_dof = 0 if ground_body is not None else 6
|
||||
jacobian_internal_dof = jacobian_nullity - trivial_dof
|
||||
|
||||
geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
|
||||
|
||||
# Rigidity: numerically rigid if nullity == trivial DOF
|
||||
is_rigid = jacobian_nullity <= trivial_dof
|
||||
is_minimally_rigid = is_rigid and len(dependent) == 0
|
||||
|
||||
return ConstraintAnalysis(
|
||||
combinatorial_dof=effective_dof,
|
||||
combinatorial_internal_dof=effective_internal_dof,
|
||||
combinatorial_redundant=pg.get_redundant_count(),
|
||||
combinatorial_classification=combinatorial_classification,
|
||||
per_edge_results=all_edge_results,
|
||||
jacobian_rank=jacobian_rank,
|
||||
jacobian_nullity=jacobian_nullity,
|
||||
jacobian_internal_dof=max(0, jacobian_internal_dof),
|
||||
numerically_dependent=dependent,
|
||||
geometric_degeneracies=geometric_degeneracies,
|
||||
is_rigid=is_rigid,
|
||||
is_minimally_rigid=is_minimally_rigid,
|
||||
)
|
||||
624
solver/datagen/dataset.py
Normal file
624
solver/datagen/dataset.py
Normal file
@@ -0,0 +1,624 @@
|
||||
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
|
||||
|
||||
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
|
||||
for parallel generation of synthetic assembly training data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"parse_simple_yaml",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
"""Configuration for synthetic dataset generation."""
|
||||
|
||||
name: str = "synthetic"
|
||||
num_assemblies: int = 100_000
|
||||
output_dir: str = "data/synthetic"
|
||||
shard_size: int = 1000
|
||||
complexity_distribution: dict[str, float] = field(
|
||||
default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
|
||||
)
|
||||
body_count_min: int = 2
|
||||
body_count_max: int = 50
|
||||
templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"])
|
||||
grounded_ratio: float = 0.5
|
||||
seed: int = 42
|
||||
num_workers: int = 4
|
||||
checkpoint_every: int = 5
|
||||
resume: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
|
||||
"""Construct from a parsed config dict (e.g. YAML or OmegaConf).
|
||||
|
||||
Handles both flat keys (``body_count_min``) and nested forms
|
||||
(``body_count: {min: 2, max: 50}``).
|
||||
"""
|
||||
kw: dict[str, Any] = {}
|
||||
for key in (
|
||||
"name",
|
||||
"num_assemblies",
|
||||
"output_dir",
|
||||
"shard_size",
|
||||
"grounded_ratio",
|
||||
"seed",
|
||||
"num_workers",
|
||||
"checkpoint_every",
|
||||
"resume",
|
||||
):
|
||||
if key in d:
|
||||
kw[key] = d[key]
|
||||
|
||||
# Handle nested body_count dict
|
||||
if "body_count" in d and isinstance(d["body_count"], dict):
|
||||
bc = d["body_count"]
|
||||
if "min" in bc:
|
||||
kw["body_count_min"] = int(bc["min"])
|
||||
if "max" in bc:
|
||||
kw["body_count_max"] = int(bc["max"])
|
||||
else:
|
||||
if "body_count_min" in d:
|
||||
kw["body_count_min"] = int(d["body_count_min"])
|
||||
if "body_count_max" in d:
|
||||
kw["body_count_max"] = int(d["body_count_max"])
|
||||
|
||||
if "complexity_distribution" in d:
|
||||
cd = d["complexity_distribution"]
|
||||
if isinstance(cd, dict):
|
||||
kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()}
|
||||
if "templates" in d and isinstance(d["templates"], list):
|
||||
kw["templates"] = [str(t) for t in d["templates"]]
|
||||
|
||||
return cls(**kw)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard specification / result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardSpec:
|
||||
"""Specification for generating a single shard."""
|
||||
|
||||
shard_id: int
|
||||
start_example_id: int
|
||||
count: int
|
||||
seed: int
|
||||
complexity_distribution: dict[str, float]
|
||||
body_count_min: int
|
||||
body_count_max: int
|
||||
grounded_ratio: float
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShardResult:
|
||||
"""Result returned from a shard worker."""
|
||||
|
||||
shard_id: int
|
||||
num_examples: int
|
||||
file_path: str
|
||||
generation_time_s: float
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
|
||||
"""Derive a deterministic per-shard seed from the global seed."""
|
||||
h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest()
|
||||
return int(h[:8], 16)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PrintProgress:
|
||||
"""Fallback progress display when tqdm is unavailable."""
|
||||
|
||||
def __init__(self, total: int) -> None:
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
def update(self, n: int = 1) -> None:
|
||||
self.current += n
|
||||
elapsed = time.monotonic() - self.start_time
|
||||
rate = self.current / elapsed if elapsed > 0 else 0.0
|
||||
eta = (self.total - self.current) / rate if rate > 0 else 0.0
|
||||
pct = 100.0 * self.current / self.total
|
||||
sys.stdout.write(
|
||||
f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
|
||||
f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _make_progress(total: int) -> _PrintProgress:
|
||||
"""Create a progress tracker (tqdm if available, else print-based)."""
|
||||
try:
|
||||
from tqdm import tqdm # type: ignore[import-untyped]
|
||||
|
||||
return tqdm(total=total, desc="Generating shards", unit="shard") # type: ignore[no-any-return,return-value]
|
||||
except ImportError:
|
||||
return _PrintProgress(total)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _save_shard(
|
||||
shard_id: int,
|
||||
examples: list[dict[str, Any]],
|
||||
shards_dir: Path,
|
||||
) -> Path:
|
||||
"""Save a shard to disk (.pt if torch available, else .json)."""
|
||||
shards_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
path = shards_dir / f"shard_{shard_id:05d}.pt"
|
||||
torch.save(examples, path)
|
||||
except ImportError:
|
||||
path = shards_dir / f"shard_{shard_id:05d}.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(examples, f)
|
||||
return path
|
||||
|
||||
|
||||
def _load_shard(path: Path) -> list[dict[str, Any]]:
|
||||
"""Load a shard from disk (.pt or .json)."""
|
||||
if path.suffix == ".pt":
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
result: list[dict[str, Any]] = torch.load(path, weights_only=False)
|
||||
return result
|
||||
with open(path) as f:
|
||||
result = json.load(f)
|
||||
return result
|
||||
|
||||
|
||||
def _shard_format() -> str:
|
||||
"""Return the shard file extension based on available libraries."""
|
||||
try:
|
||||
import torch # type: ignore[import-untyped] # noqa: F401
|
||||
|
||||
return ".pt"
|
||||
except ImportError:
|
||||
return ".json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard worker (module-level for pickling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
|
||||
"""Generate a single shard — top-level function for ProcessPoolExecutor."""
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
|
||||
t0 = time.monotonic()
|
||||
gen = SyntheticAssemblyGenerator(seed=spec.seed)
|
||||
rng = np.random.default_rng(spec.seed + 1)
|
||||
|
||||
tiers = list(spec.complexity_distribution.keys())
|
||||
probs_list = [spec.complexity_distribution[t] for t in tiers]
|
||||
total = sum(probs_list)
|
||||
probs = [p / total for p in probs_list]
|
||||
|
||||
examples: list[dict[str, Any]] = []
|
||||
for i in range(spec.count):
|
||||
tier_idx = int(rng.choice(len(tiers), p=probs))
|
||||
tier = tiers[tier_idx]
|
||||
try:
|
||||
batch = gen.generate_training_batch(
|
||||
batch_size=1,
|
||||
complexity_tier=tier, # type: ignore[arg-type]
|
||||
grounded_ratio=spec.grounded_ratio,
|
||||
)
|
||||
ex = batch[0]
|
||||
ex["example_id"] = spec.start_example_id + i
|
||||
ex["complexity_tier"] = tier
|
||||
examples.append(ex)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"Shard %d, example %d failed — skipping",
|
||||
spec.shard_id,
|
||||
i,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
shards_dir = Path(output_dir) / "shards"
|
||||
path = _save_shard(spec.shard_id, examples, shards_dir)
|
||||
|
||||
elapsed = time.monotonic() - t0
|
||||
return ShardResult(
|
||||
shard_id=spec.shard_id,
|
||||
num_examples=len(examples),
|
||||
file_path=str(path),
|
||||
generation_time_s=elapsed,
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_scalar(value: str) -> int | float | bool | str:
|
||||
"""Parse a YAML scalar value."""
|
||||
# Strip inline comments (space + #)
|
||||
if " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
elif " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
v = value.strip()
|
||||
if v.lower() in ("true", "yes"):
|
||||
return True
|
||||
if v.lower() in ("false", "no"):
|
||||
return False
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(v)
|
||||
except ValueError:
|
||||
pass
|
||||
return v.strip("'\"")
|
||||
|
||||
|
||||
def parse_simple_yaml(path: str) -> dict[str, Any]:
|
||||
"""Parse a simple YAML file (flat scalars, one-level dicts, lists).
|
||||
|
||||
This is **not** a full YAML parser. It handles the structure of
|
||||
``configs/dataset/synthetic.yaml``.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
current_key: str | None = None
|
||||
|
||||
with open(path) as f:
|
||||
for raw_line in f:
|
||||
line = raw_line.rstrip()
|
||||
|
||||
# Skip blank lines and full-line comments
|
||||
if not line or line.lstrip().startswith("#"):
|
||||
continue
|
||||
|
||||
indent = len(line) - len(line.lstrip())
|
||||
|
||||
if indent == 0 and ":" in line:
|
||||
key, _, value = line.partition(":")
|
||||
key = key.strip()
|
||||
value = value.strip()
|
||||
if value:
|
||||
result[key] = _parse_scalar(value)
|
||||
current_key = None
|
||||
else:
|
||||
current_key = key
|
||||
result[key] = {}
|
||||
continue
|
||||
|
||||
if indent > 0 and line.lstrip().startswith("- "):
|
||||
item = line.lstrip()[2:].strip()
|
||||
if current_key is not None:
|
||||
if isinstance(result.get(current_key), dict) and not result[current_key]:
|
||||
result[current_key] = []
|
||||
if isinstance(result.get(current_key), list):
|
||||
result[current_key].append(_parse_scalar(item))
|
||||
continue
|
||||
|
||||
if indent > 0 and ":" in line and current_key is not None:
|
||||
k, _, v = line.partition(":")
|
||||
k = k.strip()
|
||||
v = v.strip()
|
||||
if v:
|
||||
# Strip inline comments
|
||||
if " #" in v:
|
||||
v = v[: v.index(" #")].strip()
|
||||
if not isinstance(result.get(current_key), dict):
|
||||
result[current_key] = {}
|
||||
result[current_key][k] = _parse_scalar(v)
|
||||
continue
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset generator orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DatasetGenerator:
|
||||
"""Orchestrates parallel dataset generation with sharding and checkpointing."""
|
||||
|
||||
def __init__(self, config: DatasetConfig) -> None:
|
||||
self.config = config
|
||||
self.output_path = Path(config.output_dir)
|
||||
self.shards_dir = self.output_path / "shards"
|
||||
self.checkpoint_file = self.output_path / ".checkpoint.json"
|
||||
self.index_file = self.output_path / "index.json"
|
||||
self.stats_file = self.output_path / "stats.json"
|
||||
|
||||
# -- public API --
|
||||
|
||||
def run(self) -> None:
|
||||
"""Generate the full dataset."""
|
||||
self.output_path.mkdir(parents=True, exist_ok=True)
|
||||
self.shards_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
shards = self._plan_shards()
|
||||
total_shards = len(shards)
|
||||
|
||||
# Resume: find already-completed shards
|
||||
completed: set[int] = set()
|
||||
if self.config.resume:
|
||||
completed = self._find_completed_shards()
|
||||
|
||||
pending = [s for s in shards if s.shard_id not in completed]
|
||||
|
||||
if not pending:
|
||||
logger.info("All %d shards already complete.", total_shards)
|
||||
else:
|
||||
logger.info(
|
||||
"Generating %d shards (%d already complete).",
|
||||
len(pending),
|
||||
len(completed),
|
||||
)
|
||||
progress = _make_progress(len(pending))
|
||||
workers = max(1, self.config.num_workers)
|
||||
checkpoint_counter = 0
|
||||
|
||||
with ProcessPoolExecutor(max_workers=workers) as pool:
|
||||
futures = {
|
||||
pool.submit(
|
||||
_generate_shard_worker,
|
||||
spec,
|
||||
str(self.output_path),
|
||||
): spec.shard_id
|
||||
for spec in pending
|
||||
}
|
||||
for future in as_completed(futures):
|
||||
shard_id = futures[future]
|
||||
try:
|
||||
result = future.result()
|
||||
completed.add(result.shard_id)
|
||||
logger.debug(
|
||||
"Shard %d: %d examples in %.1fs",
|
||||
result.shard_id,
|
||||
result.num_examples,
|
||||
result.generation_time_s,
|
||||
)
|
||||
except Exception:
|
||||
logger.error("Shard %d failed", shard_id, exc_info=True)
|
||||
progress.update(1)
|
||||
checkpoint_counter += 1
|
||||
if checkpoint_counter >= self.config.checkpoint_every:
|
||||
self._update_checkpoint(completed, total_shards)
|
||||
checkpoint_counter = 0
|
||||
|
||||
progress.close()
|
||||
|
||||
# Finalize
|
||||
self._build_index()
|
||||
stats = self._compute_statistics()
|
||||
self._write_statistics(stats)
|
||||
self._print_summary(stats)
|
||||
|
||||
# Remove checkpoint (generation complete)
|
||||
if self.checkpoint_file.exists():
|
||||
self.checkpoint_file.unlink()
|
||||
|
||||
# -- internal helpers --
|
||||
|
||||
def _plan_shards(self) -> list[ShardSpec]:
|
||||
"""Divide num_assemblies into shards."""
|
||||
n = self.config.num_assemblies
|
||||
size = self.config.shard_size
|
||||
num_shards = math.ceil(n / size)
|
||||
shards: list[ShardSpec] = []
|
||||
for i in range(num_shards):
|
||||
start = i * size
|
||||
count = min(size, n - start)
|
||||
shards.append(
|
||||
ShardSpec(
|
||||
shard_id=i,
|
||||
start_example_id=start,
|
||||
count=count,
|
||||
seed=_derive_shard_seed(self.config.seed, i),
|
||||
complexity_distribution=dict(self.config.complexity_distribution),
|
||||
body_count_min=self.config.body_count_min,
|
||||
body_count_max=self.config.body_count_max,
|
||||
grounded_ratio=self.config.grounded_ratio,
|
||||
)
|
||||
)
|
||||
return shards
|
||||
|
||||
def _find_completed_shards(self) -> set[int]:
|
||||
"""Scan shards directory for existing shard files."""
|
||||
completed: set[int] = set()
|
||||
if not self.shards_dir.exists():
|
||||
return completed
|
||||
|
||||
for p in self.shards_dir.iterdir():
|
||||
if p.stem.startswith("shard_"):
|
||||
try:
|
||||
shard_id = int(p.stem.split("_")[1])
|
||||
# Verify file is non-empty
|
||||
if p.stat().st_size > 0:
|
||||
completed.add(shard_id)
|
||||
except (ValueError, IndexError):
|
||||
pass
|
||||
return completed
|
||||
|
||||
def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
|
||||
"""Write checkpoint file."""
|
||||
data = {
|
||||
"completed_shards": sorted(completed),
|
||||
"total_shards": total_shards,
|
||||
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
|
||||
}
|
||||
with open(self.checkpoint_file, "w") as f:
|
||||
json.dump(data, f)
|
||||
|
||||
def _build_index(self) -> None:
|
||||
"""Build index.json mapping shard files to assembly ID ranges."""
|
||||
shards_info: dict[str, dict[str, int]] = {}
|
||||
total_assemblies = 0
|
||||
|
||||
for p in sorted(self.shards_dir.iterdir()):
|
||||
if not p.stem.startswith("shard_"):
|
||||
continue
|
||||
try:
|
||||
shard_id = int(p.stem.split("_")[1])
|
||||
except (ValueError, IndexError):
|
||||
continue
|
||||
examples = _load_shard(p)
|
||||
count = len(examples)
|
||||
start_id = shard_id * self.config.shard_size
|
||||
shards_info[p.name] = {"start_id": start_id, "count": count}
|
||||
total_assemblies += count
|
||||
|
||||
fmt = _shard_format().lstrip(".")
|
||||
index = {
|
||||
"format_version": 1,
|
||||
"total_assemblies": total_assemblies,
|
||||
"total_shards": len(shards_info),
|
||||
"shard_format": fmt,
|
||||
"shards": shards_info,
|
||||
}
|
||||
with open(self.index_file, "w") as f:
|
||||
json.dump(index, f, indent=2)
|
||||
|
||||
def _compute_statistics(self) -> dict[str, Any]:
|
||||
"""Aggregate statistics across all shards."""
|
||||
classification_counts: dict[str, int] = {}
|
||||
body_count_hist: dict[int, int] = {}
|
||||
joint_type_counts: dict[str, int] = {}
|
||||
dof_values: list[int] = []
|
||||
degeneracy_values: list[int] = []
|
||||
rigid_count = 0
|
||||
minimally_rigid_count = 0
|
||||
total = 0
|
||||
|
||||
for p in sorted(self.shards_dir.iterdir()):
|
||||
if not p.stem.startswith("shard_"):
|
||||
continue
|
||||
examples = _load_shard(p)
|
||||
for ex in examples:
|
||||
total += 1
|
||||
cls = str(ex.get("assembly_classification", "unknown"))
|
||||
classification_counts[cls] = classification_counts.get(cls, 0) + 1
|
||||
nb = int(ex.get("n_bodies", 0))
|
||||
body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
|
||||
for j in ex.get("joints", []):
|
||||
jt = str(j.get("type", "unknown"))
|
||||
joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
|
||||
dof_values.append(int(ex.get("internal_dof", 0)))
|
||||
degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
|
||||
if ex.get("is_rigid"):
|
||||
rigid_count += 1
|
||||
if ex.get("is_minimally_rigid"):
|
||||
minimally_rigid_count += 1
|
||||
|
||||
dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
|
||||
deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)
|
||||
|
||||
return {
|
||||
"total_examples": total,
|
||||
"classification_distribution": dict(sorted(classification_counts.items())),
|
||||
"body_count_histogram": dict(sorted(body_count_hist.items())),
|
||||
"joint_type_distribution": dict(sorted(joint_type_counts.items())),
|
||||
"dof_statistics": {
|
||||
"mean": float(dof_arr.mean()),
|
||||
"std": float(dof_arr.std()),
|
||||
"min": int(dof_arr.min()),
|
||||
"max": int(dof_arr.max()),
|
||||
"median": float(np.median(dof_arr)),
|
||||
},
|
||||
"geometric_degeneracy": {
|
||||
"assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
|
||||
"fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
|
||||
"mean_degeneracies": float(deg_arr.mean()),
|
||||
},
|
||||
"rigidity": {
|
||||
"rigid_count": rigid_count,
|
||||
"rigid_fraction": (rigid_count / total if total > 0 else 0.0),
|
||||
"minimally_rigid_count": minimally_rigid_count,
|
||||
"minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
|
||||
},
|
||||
}
|
||||
|
||||
def _write_statistics(self, stats: dict[str, Any]) -> None:
|
||||
"""Write stats.json."""
|
||||
with open(self.stats_file, "w") as f:
|
||||
json.dump(stats, f, indent=2)
|
||||
|
||||
def _print_summary(self, stats: dict[str, Any]) -> None:
|
||||
"""Print a human-readable summary to stdout."""
|
||||
print("\n=== Dataset Generation Summary ===")
|
||||
print(f"Total examples: {stats['total_examples']}")
|
||||
print(f"Output directory: {self.output_path}")
|
||||
print()
|
||||
print("Classification distribution:")
|
||||
for cls, count in stats["classification_distribution"].items():
|
||||
frac = count / max(stats["total_examples"], 1) * 100
|
||||
print(f" {cls}: {count} ({frac:.1f}%)")
|
||||
print()
|
||||
print("Joint type distribution:")
|
||||
for jt, count in stats["joint_type_distribution"].items():
|
||||
print(f" {jt}: {count}")
|
||||
print()
|
||||
dof = stats["dof_statistics"]
|
||||
print(
|
||||
f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
|
||||
)
|
||||
rig = stats["rigidity"]
|
||||
print(
|
||||
f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
|
||||
f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
|
||||
f"{rig['minimally_rigid_count']} minimally rigid"
|
||||
)
|
||||
deg = stats["geometric_degeneracy"]
|
||||
print(
|
||||
f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
|
||||
f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
|
||||
)
|
||||
893
solver/datagen/generator.py
Normal file
893
solver/datagen/generator.py
Normal file
@@ -0,0 +1,893 @@
|
||||
"""Synthetic assembly graph generator for training data production.
|
||||
|
||||
Generates assembly graphs with known constraint classifications using
|
||||
the pebble game and Jacobian verification. Each assembly is fully labeled
|
||||
with per-constraint independence flags and assembly-level classification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.labeling import label_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AxisStrategy",
|
||||
"ComplexityTier",
|
||||
"SyntheticAssemblyGenerator",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
ComplexityTier = Literal["simple", "medium", "complex"]
|
||||
|
||||
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
|
||||
"simple": (2, 6),
|
||||
"medium": (6, 16),
|
||||
"complex": (16, 51),
|
||||
}
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
|
||||
|
||||
|
||||
class SyntheticAssemblyGenerator:
|
||||
"""Generates assembly graphs with known minimal constraint sets.
|
||||
|
||||
Uses the pebble game to incrementally build assemblies, tracking
|
||||
exactly which constraints are independent at each step. This produces
|
||||
labeled training data: (assembly_graph, constraint_set, labels).
|
||||
|
||||
Labels per constraint:
|
||||
- independent: bool (does this constraint remove a DOF?)
|
||||
- redundant: bool (is this constraint overconstrained?)
|
||||
- minimal_set: bool (part of a minimal rigidity basis?)
|
||||
"""
|
||||
|
||||
def __init__(self, seed: int = 42) -> None:
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _random_position(self, scale: float = 5.0) -> np.ndarray:
|
||||
"""Generate random 3D position within [-scale, scale] cube."""
|
||||
return self.rng.uniform(-scale, scale, size=3)
|
||||
|
||||
def _random_axis(self) -> np.ndarray:
|
||||
"""Generate random normalized 3D axis."""
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
return axis
|
||||
|
||||
def _random_orientation(self) -> np.ndarray:
|
||||
"""Generate a random 3x3 rotation matrix."""
|
||||
mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
|
||||
return mat
|
||||
|
||||
def _cardinal_axis(self) -> np.ndarray:
|
||||
"""Pick uniformly from the six signed cardinal directions."""
|
||||
axes = np.array(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
result: np.ndarray = axes[int(self.rng.integers(6))]
|
||||
return result
|
||||
|
||||
def _near_parallel_axis(
|
||||
self,
|
||||
base_axis: np.ndarray,
|
||||
perturbation_scale: float = 0.05,
|
||||
) -> np.ndarray:
|
||||
"""Return *base_axis* with a small random perturbation, re-normalized."""
|
||||
perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
|
||||
return perturbed / np.linalg.norm(perturbed)
|
||||
|
||||
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
|
||||
"""Sample a joint axis using the specified strategy."""
|
||||
if strategy == "cardinal":
|
||||
return self._cardinal_axis()
|
||||
if strategy == "near_parallel":
|
||||
return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
|
||||
return self._random_axis()
|
||||
|
||||
def _resolve_axis(
|
||||
self,
|
||||
strategy: AxisStrategy,
|
||||
parallel_axis_prob: float,
|
||||
shared_axis: np.ndarray | None,
|
||||
) -> tuple[np.ndarray, np.ndarray | None]:
|
||||
"""Return (axis_for_this_joint, shared_axis_to_propagate).
|
||||
|
||||
On the first call where *shared_axis* is ``None`` and parallel
|
||||
injection triggers, a base axis is chosen and returned as
|
||||
*shared_axis* for subsequent calls.
|
||||
"""
|
||||
if shared_axis is not None:
|
||||
return self._near_parallel_axis(shared_axis), shared_axis
|
||||
if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
|
||||
base = self._sample_axis(strategy)
|
||||
return base.copy(), base
|
||||
return self._sample_axis(strategy), None
|
||||
|
||||
def _select_joint_type(
|
||||
self,
|
||||
joint_types: JointType | list[JointType],
|
||||
) -> JointType:
|
||||
"""Select a joint type from a single type or list."""
|
||||
if isinstance(joint_types, list):
|
||||
idx = int(self.rng.integers(0, len(joint_types)))
|
||||
return joint_types[idx]
|
||||
return joint_types
|
||||
|
||||
def _create_joint(
|
||||
self,
|
||||
joint_id: int,
|
||||
body_a_id: int,
|
||||
body_b_id: int,
|
||||
pos_a: np.ndarray,
|
||||
pos_b: np.ndarray,
|
||||
joint_type: JointType,
|
||||
*,
|
||||
axis: np.ndarray | None = None,
|
||||
orient_a: np.ndarray | None = None,
|
||||
orient_b: np.ndarray | None = None,
|
||||
) -> Joint:
|
||||
"""Create a joint between two bodies.
|
||||
|
||||
When orientations are provided, anchor points are offset from
|
||||
each body's center along a random local direction rotated into
|
||||
world frame, rather than placed at the midpoint.
|
||||
"""
|
||||
if orient_a is not None and orient_b is not None:
|
||||
dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
|
||||
offset_scale = dist * 0.2
|
||||
local_a = self.rng.standard_normal(3) * offset_scale
|
||||
local_b = self.rng.standard_normal(3) * offset_scale
|
||||
anchor_a = pos_a + orient_a @ local_a
|
||||
anchor_b = pos_b + orient_b @ local_b
|
||||
else:
|
||||
anchor = (pos_a + pos_b) / 2.0
|
||||
anchor_a = anchor
|
||||
anchor_b = anchor
|
||||
|
||||
return Joint(
|
||||
joint_id=joint_id,
|
||||
body_a=body_a_id,
|
||||
body_b=body_b_id,
|
||||
joint_type=joint_type,
|
||||
anchor_a=anchor_a,
|
||||
anchor_b=anchor_b,
|
||||
axis=axis if axis is not None else self._random_axis(),
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_chain_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_type: JointType = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a serial kinematic chain.
|
||||
|
||||
Simple but useful: each body connects to the next with the
|
||||
specified joint type. Results in an underconstrained assembly
|
||||
(serial chain is never rigid without closing loops).
|
||||
"""
|
||||
bodies = []
|
||||
joints = []
|
||||
|
||||
for i in range(n_bodies):
|
||||
pos = np.array([i * 2.0, 0.0, 0.0])
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=pos,
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(n_bodies - 1):
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i,
|
||||
i,
|
||||
i + 1,
|
||||
bodies[i].position,
|
||||
bodies[i + 1].position,
|
||||
joint_type,
|
||||
axis=axis,
|
||||
orient_a=bodies[i].orientation,
|
||||
orient_b=bodies[i + 1].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_rigid_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a minimally rigid assembly by adding joints until rigid.
|
||||
|
||||
Strategy: start with fixed joints on a spanning tree (guarantees
|
||||
rigidity), then randomly relax some to weaker joint types while
|
||||
maintaining rigidity via the pebble game check.
|
||||
"""
|
||||
bodies = []
|
||||
for i in range(n_bodies):
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=self._random_position(),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
# Build spanning tree with fixed joints (overconstrained)
|
||||
joints: list[Joint] = []
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
parent = int(self.rng.integers(0, i))
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i - 1,
|
||||
parent,
|
||||
i,
|
||||
bodies[parent].position,
|
||||
bodies[i].position,
|
||||
JointType.FIXED,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent].orientation,
|
||||
orient_b=bodies[i].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
# Try relaxing joints to weaker types while maintaining rigidity
|
||||
weaker_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.CYLINDRICAL,
|
||||
JointType.BALL,
|
||||
]
|
||||
|
||||
ground = 0 if grounded else None
|
||||
for idx in self.rng.permutation(len(joints)):
|
||||
original_type = joints[idx].joint_type
|
||||
for candidate in weaker_types:
|
||||
joints[idx].joint_type = candidate
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
if analysis.is_rigid:
|
||||
break # Keep the weaker type
|
||||
else:
|
||||
joints[idx].joint_type = original_type
|
||||
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_overconstrained_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
extra_joints: int = 2,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate an assembly with known redundant constraints.
|
||||
|
||||
Starts with a rigid assembly, then adds extra joints that
|
||||
the pebble game will flag as redundant.
|
||||
"""
|
||||
bodies, joints, _ = self.generate_rigid_assembly(
|
||||
n_bodies,
|
||||
grounded=grounded,
|
||||
axis_strategy=axis_strategy,
|
||||
parallel_axis_prob=parallel_axis_prob,
|
||||
)
|
||||
|
||||
joint_id = len(joints)
|
||||
shared_axis: np.ndarray | None = None
|
||||
for _ in range(extra_joints):
|
||||
a, b = self.rng.choice(n_bodies, size=2, replace=False)
|
||||
_overcon_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.FIXED,
|
||||
JointType.BALL,
|
||||
]
|
||||
jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
int(a),
|
||||
int(b),
|
||||
bodies[int(a)].position,
|
||||
bodies[int(b)].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[int(a)].orientation,
|
||||
orient_b=bodies[int(b)].orientation,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
ground = 0 if grounded else None
|
||||
analysis = analyze_assembly(bodies, joints, ground_body=ground)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_tree_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
branching_factor: int = 3,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a random tree topology with configurable branching.
|
||||
|
||||
Creates a tree where each body can have up to *branching_factor*
|
||||
children. Different branches can use different joint types if a
|
||||
list is provided. Always underconstrained (no closed loops).
|
||||
|
||||
Args:
|
||||
n_bodies: Total bodies (root + children).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
branching_factor: Max children per parent (1-5 recommended).
|
||||
"""
|
||||
bodies: list[RigidBody] = [
|
||||
RigidBody(
|
||||
body_id=0,
|
||||
position=np.zeros(3),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
]
|
||||
joints: list[Joint] = []
|
||||
|
||||
available_parents = [0]
|
||||
next_id = 1
|
||||
joint_id = 0
|
||||
shared_axis: np.ndarray | None = None
|
||||
|
||||
while next_id < n_bodies and available_parents:
|
||||
pidx = int(self.rng.integers(0, len(available_parents)))
|
||||
parent_id = available_parents[pidx]
|
||||
parent_pos = bodies[parent_id].position
|
||||
|
||||
max_children = min(branching_factor, n_bodies - next_id)
|
||||
n_children = int(self.rng.integers(1, max_children + 1))
|
||||
|
||||
for _ in range(n_children):
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(1.5, 3.0)
|
||||
child_pos = parent_pos + direction * distance
|
||||
child_orient = self._random_orientation()
|
||||
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=next_id,
|
||||
position=child_pos,
|
||||
orientation=child_orient,
|
||||
)
|
||||
)
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
parent_id,
|
||||
next_id,
|
||||
parent_pos,
|
||||
child_pos,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent_id].orientation,
|
||||
orient_b=child_orient,
|
||||
)
|
||||
)
|
||||
|
||||
available_parents.append(next_id)
|
||||
next_id += 1
|
||||
joint_id += 1
|
||||
if next_id >= n_bodies:
|
||||
break
|
||||
|
||||
# Retire parent if it reached branching limit or randomly
|
||||
if n_children >= branching_factor or self.rng.random() < 0.3:
|
||||
available_parents.pop(pidx)
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_loop_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a single closed loop (ring) of bodies.
|
||||
|
||||
The closing constraint introduces redundancy, making this
|
||||
useful for generating overconstrained training data.
|
||||
|
||||
Args:
|
||||
n_bodies: Bodies in the ring (>= 3).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
|
||||
Raises:
|
||||
ValueError: If *n_bodies* < 3.
|
||||
"""
|
||||
if n_bodies < 3:
|
||||
msg = "Loop assembly requires at least 3 bodies"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = []
|
||||
joints: list[Joint] = []
|
||||
|
||||
base_radius = max(2.0, n_bodies * 0.4)
|
||||
for i in range(n_bodies):
|
||||
angle = 2 * np.pi * i / n_bodies
|
||||
radius = base_radius + self.rng.uniform(-0.5, 0.5)
|
||||
x = radius * np.cos(angle)
|
||||
y = radius * np.sin(angle)
|
||||
z = float(self.rng.uniform(-0.2, 0.2))
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=np.array([x, y, z]),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(n_bodies):
|
||||
next_i = (i + 1) % n_bodies
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i,
|
||||
i,
|
||||
next_i,
|
||||
bodies[i].position,
|
||||
bodies[next_i].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[i].orientation,
|
||||
orient_b=bodies[next_i].orientation,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_star_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a star topology with central hub and satellites.
|
||||
|
||||
Body 0 is the hub; all other bodies connect directly to it.
|
||||
Underconstrained because there are no inter-satellite connections.
|
||||
|
||||
Args:
|
||||
n_bodies: Total bodies including hub (>= 2).
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
|
||||
Raises:
|
||||
ValueError: If *n_bodies* < 2.
|
||||
"""
|
||||
if n_bodies < 2:
|
||||
msg = "Star assembly requires at least 2 bodies"
|
||||
raise ValueError(msg)
|
||||
|
||||
hub_orient = self._random_orientation()
|
||||
bodies: list[RigidBody] = [
|
||||
RigidBody(
|
||||
body_id=0,
|
||||
position=np.zeros(3),
|
||||
orientation=hub_orient,
|
||||
)
|
||||
]
|
||||
joints: list[Joint] = []
|
||||
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
direction = self._random_axis()
|
||||
distance = self.rng.uniform(2.0, 5.0)
|
||||
pos = direction * distance
|
||||
sat_orient = self._random_orientation()
|
||||
bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))
|
||||
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
i - 1,
|
||||
0,
|
||||
i,
|
||||
np.zeros(3),
|
||||
pos,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=hub_orient,
|
||||
orient_b=sat_orient,
|
||||
)
|
||||
)
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
def generate_mixed_assembly(
|
||||
self,
|
||||
n_bodies: int,
|
||||
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
|
||||
edge_density: float = 0.3,
|
||||
*,
|
||||
grounded: bool = True,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
|
||||
"""Generate a mixed topology combining tree and loop elements.
|
||||
|
||||
Builds a spanning tree for connectivity, then adds extra edges
|
||||
based on *edge_density* to create loops and redundancy.
|
||||
|
||||
Args:
|
||||
n_bodies: Number of bodies.
|
||||
joint_types: Single type or list to sample from per joint.
|
||||
edge_density: Fraction of non-tree edges to add (0.0-1.0).
|
||||
|
||||
Raises:
|
||||
ValueError: If *edge_density* not in [0.0, 1.0].
|
||||
"""
|
||||
if not 0.0 <= edge_density <= 1.0:
|
||||
msg = "edge_density must be in [0.0, 1.0]"
|
||||
raise ValueError(msg)
|
||||
|
||||
bodies: list[RigidBody] = []
|
||||
joints: list[Joint] = []
|
||||
|
||||
for i in range(n_bodies):
|
||||
bodies.append(
|
||||
RigidBody(
|
||||
body_id=i,
|
||||
position=self._random_position(),
|
||||
orientation=self._random_orientation(),
|
||||
)
|
||||
)
|
||||
|
||||
# Phase 1: spanning tree
|
||||
joint_id = 0
|
||||
existing_edges: set[frozenset[int]] = set()
|
||||
shared_axis: np.ndarray | None = None
|
||||
for i in range(1, n_bodies):
|
||||
parent = int(self.rng.integers(0, i))
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
parent,
|
||||
i,
|
||||
bodies[parent].position,
|
||||
bodies[i].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[parent].orientation,
|
||||
orient_b=bodies[i].orientation,
|
||||
)
|
||||
)
|
||||
existing_edges.add(frozenset([parent, i]))
|
||||
joint_id += 1
|
||||
|
||||
# Phase 2: extra edges based on density
|
||||
candidates: list[tuple[int, int]] = []
|
||||
for i in range(n_bodies):
|
||||
for j in range(i + 1, n_bodies):
|
||||
if frozenset([i, j]) not in existing_edges:
|
||||
candidates.append((i, j))
|
||||
|
||||
n_extra = int(edge_density * len(candidates))
|
||||
self.rng.shuffle(candidates)
|
||||
|
||||
for a, b in candidates[:n_extra]:
|
||||
jtype = self._select_joint_type(joint_types)
|
||||
axis, shared_axis = self._resolve_axis(
|
||||
axis_strategy,
|
||||
parallel_axis_prob,
|
||||
shared_axis,
|
||||
)
|
||||
joints.append(
|
||||
self._create_joint(
|
||||
joint_id,
|
||||
a,
|
||||
b,
|
||||
bodies[a].position,
|
||||
bodies[b].position,
|
||||
jtype,
|
||||
axis=axis,
|
||||
orient_a=bodies[a].orientation,
|
||||
orient_b=bodies[b].orientation,
|
||||
)
|
||||
)
|
||||
joint_id += 1
|
||||
|
||||
analysis = analyze_assembly(
|
||||
bodies,
|
||||
joints,
|
||||
ground_body=0 if grounded else None,
|
||||
)
|
||||
return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_training_batch(
|
||||
self,
|
||||
batch_size: int = 100,
|
||||
n_bodies_range: tuple[int, int] | None = None,
|
||||
complexity_tier: ComplexityTier | None = None,
|
||||
*,
|
||||
axis_strategy: AxisStrategy = "random",
|
||||
parallel_axis_prob: float = 0.0,
|
||||
grounded_ratio: float = 1.0,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Generate a batch of labeled training examples.
|
||||
|
||||
Each example contains body positions, joint descriptions,
|
||||
per-joint independence labels, and assembly-level classification.
|
||||
|
||||
Args:
|
||||
batch_size: Number of assemblies to generate.
|
||||
n_bodies_range: ``(min, max_exclusive)`` body count.
|
||||
Overridden by *complexity_tier* when both are given.
|
||||
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
|
||||
/ ``"complex"``). Overrides *n_bodies_range*.
|
||||
axis_strategy: Axis sampling strategy for joint axes.
|
||||
parallel_axis_prob: Probability of parallel axis injection.
|
||||
grounded_ratio: Fraction of examples that are grounded.
|
||||
"""
|
||||
if complexity_tier is not None:
|
||||
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
|
||||
elif n_bodies_range is None:
|
||||
n_bodies_range = (3, 8)
|
||||
|
||||
_joint_pool = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.BALL,
|
||||
JointType.CYLINDRICAL,
|
||||
JointType.FIXED,
|
||||
]
|
||||
|
||||
geo_kw: dict[str, Any] = {
|
||||
"axis_strategy": axis_strategy,
|
||||
"parallel_axis_prob": parallel_axis_prob,
|
||||
}
|
||||
|
||||
examples: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(batch_size):
|
||||
n = int(self.rng.integers(*n_bodies_range))
|
||||
gen_idx = int(self.rng.integers(7))
|
||||
grounded = bool(self.rng.random() < grounded_ratio)
|
||||
|
||||
if gen_idx == 0:
|
||||
_chain_types = [
|
||||
JointType.REVOLUTE,
|
||||
JointType.BALL,
|
||||
JointType.CYLINDRICAL,
|
||||
]
|
||||
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
|
||||
bodies, joints, analysis = self.generate_chain_assembly(
|
||||
n,
|
||||
jtype,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "chain"
|
||||
elif gen_idx == 1:
|
||||
bodies, joints, analysis = self.generate_rigid_assembly(
|
||||
n,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "rigid"
|
||||
elif gen_idx == 2:
|
||||
extra = int(self.rng.integers(1, 4))
|
||||
bodies, joints, analysis = self.generate_overconstrained_assembly(
|
||||
n,
|
||||
extra,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "overconstrained"
|
||||
elif gen_idx == 3:
|
||||
branching = int(self.rng.integers(2, 5))
|
||||
bodies, joints, analysis = self.generate_tree_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
branching,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "tree"
|
||||
elif gen_idx == 4:
|
||||
n = max(n, 3)
|
||||
bodies, joints, analysis = self.generate_loop_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "loop"
|
||||
elif gen_idx == 5:
|
||||
n = max(n, 2)
|
||||
bodies, joints, analysis = self.generate_star_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "star"
|
||||
else:
|
||||
density = float(self.rng.uniform(0.2, 0.5))
|
||||
bodies, joints, analysis = self.generate_mixed_assembly(
|
||||
n,
|
||||
_joint_pool,
|
||||
density,
|
||||
grounded=grounded,
|
||||
**geo_kw,
|
||||
)
|
||||
gen_name = "mixed"
|
||||
|
||||
# Produce ground truth labels (includes ConstraintAnalysis)
|
||||
ground = 0 if grounded else None
|
||||
labels = label_assembly(bodies, joints, ground_body=ground)
|
||||
analysis = labels.analysis
|
||||
|
||||
# Build per-joint labels from edge results
|
||||
joint_labels: dict[int, dict[str, int]] = {}
|
||||
for result in analysis.per_edge_results:
|
||||
jid = result["joint_id"]
|
||||
if jid not in joint_labels:
|
||||
joint_labels[jid] = {
|
||||
"independent_constraints": 0,
|
||||
"redundant_constraints": 0,
|
||||
"total_constraints": 0,
|
||||
}
|
||||
joint_labels[jid]["total_constraints"] += 1
|
||||
if result["independent"]:
|
||||
joint_labels[jid]["independent_constraints"] += 1
|
||||
else:
|
||||
joint_labels[jid]["redundant_constraints"] += 1
|
||||
|
||||
examples.append(
|
||||
{
|
||||
"example_id": i,
|
||||
"generator_type": gen_name,
|
||||
"grounded": grounded,
|
||||
"n_bodies": len(bodies),
|
||||
"n_joints": len(joints),
|
||||
"body_positions": [b.position.tolist() for b in bodies],
|
||||
"body_orientations": [b.orientation.tolist() for b in bodies],
|
||||
"joints": [
|
||||
{
|
||||
"joint_id": j.joint_id,
|
||||
"body_a": j.body_a,
|
||||
"body_b": j.body_b,
|
||||
"type": j.joint_type.name,
|
||||
"axis": j.axis.tolist(),
|
||||
}
|
||||
for j in joints
|
||||
],
|
||||
"joint_labels": joint_labels,
|
||||
"labels": labels.to_dict(),
|
||||
"assembly_classification": (analysis.combinatorial_classification),
|
||||
"is_rigid": analysis.is_rigid,
|
||||
"is_minimally_rigid": analysis.is_minimally_rigid,
|
||||
"internal_dof": analysis.jacobian_internal_dof,
|
||||
"geometric_degeneracies": (analysis.geometric_degeneracies),
|
||||
}
|
||||
)
|
||||
|
||||
return examples
|
||||
517
solver/datagen/jacobian.py
Normal file
517
solver/datagen/jacobian.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Numerical Jacobian rank verification for assembly constraint analysis.
|
||||
|
||||
Builds the constraint Jacobian matrix and analyzes its numerical rank
|
||||
to detect geometric degeneracies that the combinatorial pebble game
|
||||
cannot identify (e.g., parallel revolute axes creating hidden dependencies).
|
||||
|
||||
References:
|
||||
- Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["JacobianVerifier"]
|
||||
|
||||
|
||||
class JacobianVerifier:
|
||||
"""Builds and analyzes the constraint Jacobian for numerical rank check.
|
||||
|
||||
The pebble game gives a combinatorial *necessary* condition for
|
||||
rigidity. However, geometric special cases (e.g., all revolute axes
|
||||
parallel, creating a hidden dependency) require numerical verification.
|
||||
|
||||
For each joint, we construct the constraint Jacobian rows that map
|
||||
the 6n-dimensional generalized velocity vector to the constraint
|
||||
violation rates. The rank of this Jacobian equals the number of
|
||||
truly independent constraints.
|
||||
|
||||
The generalized velocity vector for n bodies is::
|
||||
|
||||
v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z]
|
||||
|
||||
Each scalar constraint C_i contributes one row to J such that::
|
||||
|
||||
dC_i/dt = J_i @ v = 0
|
||||
"""
|
||||
|
||||
def __init__(self, bodies: list[RigidBody]) -> None:
|
||||
self.bodies = {b.body_id: b for b in bodies}
|
||||
self.body_index = {b.body_id: i for i, b in enumerate(bodies)}
|
||||
self.n_bodies = len(bodies)
|
||||
self.jacobian_rows: list[np.ndarray] = []
|
||||
self.row_labels: list[dict[str, Any]] = []
|
||||
|
||||
def _body_cols(self, body_id: int) -> tuple[int, int]:
|
||||
"""Return the column range [start, end) for a body in J."""
|
||||
idx = self.body_index[body_id]
|
||||
return idx * 6, (idx + 1) * 6
|
||||
|
||||
def add_joint_constraints(self, joint: Joint) -> int:
|
||||
"""Add Jacobian rows for all scalar constraints of a joint.
|
||||
|
||||
Returns the number of rows added.
|
||||
"""
|
||||
builder = {
|
||||
JointType.FIXED: self._build_fixed,
|
||||
JointType.REVOLUTE: self._build_revolute,
|
||||
JointType.CYLINDRICAL: self._build_cylindrical,
|
||||
JointType.SLIDER: self._build_slider,
|
||||
JointType.BALL: self._build_ball,
|
||||
JointType.PLANAR: self._build_planar,
|
||||
JointType.DISTANCE: self._build_distance,
|
||||
JointType.PARALLEL: self._build_parallel,
|
||||
JointType.PERPENDICULAR: self._build_perpendicular,
|
||||
JointType.UNIVERSAL: self._build_universal,
|
||||
JointType.SCREW: self._build_screw,
|
||||
}
|
||||
|
||||
rows_before = len(self.jacobian_rows)
|
||||
builder[joint.joint_type](joint)
|
||||
return len(self.jacobian_rows) - rows_before
|
||||
|
||||
def _make_row(self) -> np.ndarray:
|
||||
"""Create a zero row of width 6*n_bodies."""
|
||||
return np.zeros(6 * self.n_bodies)
|
||||
|
||||
def _skew(self, v: np.ndarray) -> np.ndarray:
|
||||
"""Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``."""
|
||||
return np.array(
|
||||
[
|
||||
[0, -v[2], v[1]],
|
||||
[v[2], 0, -v[0]],
|
||||
[-v[1], v[0], 0],
|
||||
]
|
||||
)
|
||||
|
||||
# --- Ball-and-socket (spherical) joint: 3 translation constraints ---
|
||||
|
||||
def _build_ball(self, joint: Joint) -> None:
|
||||
"""Ball joint: coincident point constraint.
|
||||
|
||||
``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0``
|
||||
(3 equations)
|
||||
|
||||
Jacobian rows (for each of x, y, z):
|
||||
body_a linear: -I
|
||||
body_a angular: +skew(R_a @ r_a)
|
||||
body_b linear: +I
|
||||
body_b angular: -skew(R_b @ r_b)
|
||||
"""
|
||||
# Use anchor positions directly as world-frame offsets
|
||||
r_a = joint.anchor_a - self.bodies[joint.body_a].position
|
||||
r_b = joint.anchor_b - self.bodies[joint.body_b].position
|
||||
|
||||
col_a_start, col_a_end = self._body_cols(joint.body_a)
|
||||
col_b_start, col_b_end = self._body_cols(joint.body_b)
|
||||
|
||||
for axis_idx in range(3):
|
||||
row = self._make_row()
|
||||
e = np.zeros(3)
|
||||
e[axis_idx] = 1.0
|
||||
|
||||
row[col_a_start : col_a_start + 3] = -e
|
||||
row[col_a_start + 3 : col_a_end] = np.cross(r_a, e)
|
||||
|
||||
row[col_b_start : col_b_start + 3] = e
|
||||
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e)
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "ball_translation",
|
||||
"axis": axis_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Fixed joint: 3 translation + 3 rotation constraints ---
|
||||
|
||||
def _build_fixed(self, joint: Joint) -> None:
|
||||
"""Fixed joint = ball joint + locked rotation.
|
||||
|
||||
Translation part: same as ball joint (3 rows).
|
||||
Rotation part: relative angular velocity must be zero (3 rows).
|
||||
"""
|
||||
self._build_ball(joint)
|
||||
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
for axis_idx in range(3):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 + axis_idx] = -1.0
|
||||
row[col_b_start + 3 + axis_idx] = 1.0
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "fixed_rotation",
|
||||
"axis": axis_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Revolute (hinge) joint: 3 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_revolute(self, joint: Joint) -> None:
|
||||
"""Revolute joint: rotation only about one axis.
|
||||
|
||||
Translation: same as ball (3 rows).
|
||||
Rotation: relative angular velocity must be parallel to hinge axis.
|
||||
"""
|
||||
self._build_ball(joint)
|
||||
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
t1, t2 = self._perpendicular_pair(axis)
|
||||
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
for i, t in enumerate((t1, t2)):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 : col_a_start + 6] = -t
|
||||
row[col_b_start + 3 : col_b_start + 6] = t
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "revolute_rotation",
|
||||
"perp_axis": i,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Cylindrical joint: 2 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_cylindrical(self, joint: Joint) -> None:
|
||||
"""Cylindrical joint: allows rotation + translation along one axis.
|
||||
|
||||
Translation: constrain motion perpendicular to axis (2 rows).
|
||||
Rotation: constrain rotation perpendicular to axis (2 rows).
|
||||
"""
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
t1, t2 = self._perpendicular_pair(axis)
|
||||
|
||||
r_a = joint.anchor_a - self.bodies[joint.body_a].position
|
||||
r_b = joint.anchor_b - self.bodies[joint.body_b].position
|
||||
|
||||
col_a_start, col_a_end = self._body_cols(joint.body_a)
|
||||
col_b_start, col_b_end = self._body_cols(joint.body_b)
|
||||
|
||||
for i, t in enumerate((t1, t2)):
|
||||
row = self._make_row()
|
||||
row[col_a_start : col_a_start + 3] = -t
|
||||
row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
|
||||
row[col_b_start : col_b_start + 3] = t
|
||||
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "cylindrical_translation",
|
||||
"perp_axis": i,
|
||||
}
|
||||
)
|
||||
|
||||
for i, t in enumerate((t1, t2)):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 : col_a_start + 6] = -t
|
||||
row[col_b_start + 3 : col_b_start + 6] = t
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "cylindrical_rotation",
|
||||
"perp_axis": i,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Slider (prismatic) joint: 2 translation + 3 rotation constraints ---
|
||||
|
||||
def _build_slider(self, joint: Joint) -> None:
|
||||
"""Slider/prismatic joint: translation along one axis only.
|
||||
|
||||
Translation: perpendicular translation constrained (2 rows).
|
||||
Rotation: all relative rotation constrained (3 rows).
|
||||
"""
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
t1, t2 = self._perpendicular_pair(axis)
|
||||
|
||||
r_a = joint.anchor_a - self.bodies[joint.body_a].position
|
||||
r_b = joint.anchor_b - self.bodies[joint.body_b].position
|
||||
|
||||
col_a_start, col_a_end = self._body_cols(joint.body_a)
|
||||
col_b_start, col_b_end = self._body_cols(joint.body_b)
|
||||
|
||||
for i, t in enumerate((t1, t2)):
|
||||
row = self._make_row()
|
||||
row[col_a_start : col_a_start + 3] = -t
|
||||
row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
|
||||
row[col_b_start : col_b_start + 3] = t
|
||||
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "slider_translation",
|
||||
"perp_axis": i,
|
||||
}
|
||||
)
|
||||
|
||||
for axis_idx in range(3):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 + axis_idx] = -1.0
|
||||
row[col_b_start + 3 + axis_idx] = 1.0
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "slider_rotation",
|
||||
"axis": axis_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Planar joint: 1 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_planar(self, joint: Joint) -> None:
|
||||
"""Planar joint: constrains to a plane.
|
||||
|
||||
Translation: motion along plane normal constrained (1 row).
|
||||
Rotation: rotation about axes in the plane constrained (2 rows).
|
||||
"""
|
||||
normal = joint.axis / np.linalg.norm(joint.axis)
|
||||
t1, t2 = self._perpendicular_pair(normal)
|
||||
|
||||
r_a = joint.anchor_a - self.bodies[joint.body_a].position
|
||||
r_b = joint.anchor_b - self.bodies[joint.body_b].position
|
||||
|
||||
col_a_start, col_a_end = self._body_cols(joint.body_a)
|
||||
col_b_start, col_b_end = self._body_cols(joint.body_b)
|
||||
|
||||
row = self._make_row()
|
||||
row[col_a_start : col_a_start + 3] = -normal
|
||||
row[col_a_start + 3 : col_a_end] = np.cross(r_a, normal)
|
||||
row[col_b_start : col_b_start + 3] = normal
|
||||
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, normal)
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "planar_translation",
|
||||
}
|
||||
)
|
||||
|
||||
for i, t in enumerate((t1, t2)):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 : col_a_start + 6] = -t
|
||||
row[col_b_start + 3 : col_b_start + 6] = t
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "planar_rotation",
|
||||
"perp_axis": i,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Distance constraint: 1 scalar ---
|
||||
|
||||
def _build_distance(self, joint: Joint) -> None:
|
||||
"""Distance constraint: ``||p_b - p_a|| = d``.
|
||||
|
||||
Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0``
|
||||
where ``direction = normalized(p_b - p_a)``.
|
||||
"""
|
||||
p_a = joint.anchor_a
|
||||
p_b = joint.anchor_b
|
||||
diff = p_b - p_a
|
||||
dist = np.linalg.norm(diff)
|
||||
direction = np.array([1.0, 0.0, 0.0]) if dist < 1e-12 else diff / dist
|
||||
|
||||
r_a = joint.anchor_a - self.bodies[joint.body_a].position
|
||||
r_b = joint.anchor_b - self.bodies[joint.body_b].position
|
||||
|
||||
col_a_start, col_a_end = self._body_cols(joint.body_a)
|
||||
col_b_start, col_b_end = self._body_cols(joint.body_b)
|
||||
|
||||
row = self._make_row()
|
||||
row[col_a_start : col_a_start + 3] = -direction
|
||||
row[col_a_start + 3 : col_a_end] = np.cross(r_a, direction)
|
||||
row[col_b_start : col_b_start + 3] = direction
|
||||
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, direction)
|
||||
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "distance",
|
||||
}
|
||||
)
|
||||
|
||||
# --- Parallel constraint: 3 rotation constraints ---
|
||||
|
||||
def _build_parallel(self, joint: Joint) -> None:
|
||||
"""Parallel: all relative rotation constrained (same as fixed rotation).
|
||||
|
||||
In practice only 2 of 3 are independent for a single axis, but
|
||||
we emit 3 and let the rank check sort it out.
|
||||
"""
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
for axis_idx in range(3):
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 + axis_idx] = -1.0
|
||||
row[col_b_start + 3 + axis_idx] = 1.0
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "parallel_rotation",
|
||||
"axis": axis_idx,
|
||||
}
|
||||
)
|
||||
|
||||
# --- Perpendicular constraint: 1 angular ---
|
||||
|
||||
def _build_perpendicular(self, joint: Joint) -> None:
|
||||
"""Perpendicular: single dot-product angular constraint."""
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 : col_a_start + 6] = -axis
|
||||
row[col_b_start + 3 : col_b_start + 6] = axis
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "perpendicular",
|
||||
}
|
||||
)
|
||||
|
||||
# --- Universal (Cardan) joint: 3 translation + 1 rotation ---
|
||||
|
||||
def _build_universal(self, joint: Joint) -> None:
|
||||
"""Universal joint: ball + one rotation constraint.
|
||||
|
||||
Allows rotation about two axes, constrains rotation about the third.
|
||||
"""
|
||||
self._build_ball(joint)
|
||||
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
row = self._make_row()
|
||||
row[col_a_start + 3 : col_a_start + 6] = -axis
|
||||
row[col_b_start + 3 : col_b_start + 6] = axis
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "universal_rotation",
|
||||
}
|
||||
)
|
||||
|
||||
# --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled ---
|
||||
|
||||
def _build_screw(self, joint: Joint) -> None:
|
||||
"""Screw joint: coupled rotation-translation along axis.
|
||||
|
||||
Like cylindrical but with a coupling constraint:
|
||||
``v_axial - pitch * w_axial = 0``
|
||||
"""
|
||||
self._build_cylindrical(joint)
|
||||
|
||||
axis = joint.axis / np.linalg.norm(joint.axis)
|
||||
col_a_start, _ = self._body_cols(joint.body_a)
|
||||
col_b_start, _ = self._body_cols(joint.body_b)
|
||||
|
||||
row = self._make_row()
|
||||
row[col_a_start : col_a_start + 3] = -axis
|
||||
row[col_b_start : col_b_start + 3] = axis
|
||||
row[col_a_start + 3 : col_a_start + 6] = joint.pitch * axis
|
||||
row[col_b_start + 3 : col_b_start + 6] = -joint.pitch * axis
|
||||
self.jacobian_rows.append(row)
|
||||
self.row_labels.append(
|
||||
{
|
||||
"joint_id": joint.joint_id,
|
||||
"type": "screw_coupling",
|
||||
}
|
||||
)
|
||||
|
||||
# --- Utilities ---
|
||||
|
||||
def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Generate two unit vectors perpendicular to *axis* and each other."""
|
||||
if abs(axis[0]) < 0.9:
|
||||
t1 = np.cross(axis, np.array([1.0, 0, 0]))
|
||||
else:
|
||||
t1 = np.cross(axis, np.array([0, 1.0, 0]))
|
||||
t1 /= np.linalg.norm(t1)
|
||||
t2 = np.cross(axis, t1)
|
||||
t2 /= np.linalg.norm(t2)
|
||||
return t1, t2
|
||||
|
||||
def get_jacobian(self) -> np.ndarray:
|
||||
"""Return the full constraint Jacobian matrix."""
|
||||
if not self.jacobian_rows:
|
||||
return np.zeros((0, 6 * self.n_bodies))
|
||||
return np.array(self.jacobian_rows)
|
||||
|
||||
def numerical_rank(self, tol: float = 1e-8) -> int:
|
||||
"""Compute the numerical rank of the constraint Jacobian via SVD.
|
||||
|
||||
This is the number of truly independent scalar constraints,
|
||||
accounting for geometric degeneracies that the combinatorial
|
||||
pebble game cannot detect.
|
||||
"""
|
||||
j = self.get_jacobian()
|
||||
if j.size == 0:
|
||||
return 0
|
||||
sv = np.linalg.svd(j, compute_uv=False)
|
||||
return int(np.sum(sv > tol))
|
||||
|
||||
def find_dependencies(self, tol: float = 1e-8) -> list[int]:
|
||||
"""Identify which constraint rows are numerically dependent.
|
||||
|
||||
Returns indices of rows that can be removed without changing
|
||||
the Jacobian's rank.
|
||||
"""
|
||||
j = self.get_jacobian()
|
||||
if j.size == 0:
|
||||
return []
|
||||
|
||||
n_rows = j.shape[0]
|
||||
dependent: list[int] = []
|
||||
|
||||
current = np.zeros((0, j.shape[1]))
|
||||
current_rank = 0
|
||||
|
||||
for i in range(n_rows):
|
||||
candidate = np.vstack([current, j[i : i + 1, :]]) if current.size else j[i : i + 1, :]
|
||||
sv = np.linalg.svd(candidate, compute_uv=False)
|
||||
new_rank = int(np.sum(sv > tol))
|
||||
|
||||
if new_rank > current_rank:
|
||||
current = candidate
|
||||
current_rank = new_rank
|
||||
else:
|
||||
dependent.append(i)
|
||||
|
||||
return dependent
|
||||
394
solver/datagen/labeling.py
Normal file
394
solver/datagen/labeling.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""Ground truth labeling pipeline for synthetic assemblies.
|
||||
|
||||
Produces rich per-constraint, per-joint, per-body, and assembly-level
|
||||
labels by running both the pebble game and Jacobian verification and
|
||||
correlating their results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["AssemblyLabels", "label_assembly"]
|
||||
|
||||
_GROUND_ID = -1
|
||||
_SVD_TOL = 1e-8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Label dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConstraintLabel:
|
||||
"""Per scalar-constraint label combining both analysis methods."""
|
||||
|
||||
joint_id: int
|
||||
constraint_idx: int
|
||||
pebble_independent: bool
|
||||
jacobian_independent: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class JointLabel:
|
||||
"""Aggregated constraint counts for a single joint."""
|
||||
|
||||
joint_id: int
|
||||
independent_count: int
|
||||
redundant_count: int
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class BodyDofLabel:
|
||||
"""Per-body DOF signature from nullspace projection."""
|
||||
|
||||
body_id: int
|
||||
translational_dof: int
|
||||
rotational_dof: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssemblyLabel:
|
||||
"""Assembly-wide summary label."""
|
||||
|
||||
classification: str
|
||||
total_dof: int
|
||||
redundant_count: int
|
||||
is_rigid: bool
|
||||
is_minimally_rigid: bool
|
||||
has_degeneracy: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class AssemblyLabels:
|
||||
"""Complete ground truth labels for an assembly."""
|
||||
|
||||
per_constraint: list[ConstraintLabel]
|
||||
per_joint: list[JointLabel]
|
||||
per_body: list[BodyDofLabel]
|
||||
assembly: AssemblyLabel
|
||||
analysis: ConstraintAnalysis
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""Return a JSON-serializable dict."""
|
||||
return {
|
||||
"per_constraint": [
|
||||
{
|
||||
"joint_id": c.joint_id,
|
||||
"constraint_idx": c.constraint_idx,
|
||||
"pebble_independent": c.pebble_independent,
|
||||
"jacobian_independent": c.jacobian_independent,
|
||||
}
|
||||
for c in self.per_constraint
|
||||
],
|
||||
"per_joint": [
|
||||
{
|
||||
"joint_id": j.joint_id,
|
||||
"independent_count": j.independent_count,
|
||||
"redundant_count": j.redundant_count,
|
||||
"total": j.total,
|
||||
}
|
||||
for j in self.per_joint
|
||||
],
|
||||
"per_body": [
|
||||
{
|
||||
"body_id": b.body_id,
|
||||
"translational_dof": b.translational_dof,
|
||||
"rotational_dof": b.rotational_dof,
|
||||
}
|
||||
for b in self.per_body
|
||||
],
|
||||
"assembly": {
|
||||
"classification": self.assembly.classification,
|
||||
"total_dof": self.assembly.total_dof,
|
||||
"redundant_count": self.assembly.redundant_count,
|
||||
"is_rigid": self.assembly.is_rigid,
|
||||
"is_minimally_rigid": self.assembly.is_minimally_rigid,
|
||||
"has_degeneracy": self.assembly.has_degeneracy,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-body DOF from nullspace projection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_per_body_dof(
|
||||
j_reduced: np.ndarray,
|
||||
body_ids: list[int],
|
||||
ground_body: int | None,
|
||||
body_index: dict[int, int],
|
||||
) -> list[BodyDofLabel]:
|
||||
"""Compute translational and rotational DOF per body.
|
||||
|
||||
Uses SVD nullspace projection: for each body, extract its
|
||||
translational (3 cols) and rotational (3 cols) components
|
||||
from the nullspace basis and compute ranks.
|
||||
"""
|
||||
# Build column index mapping for the reduced Jacobian
|
||||
# (ground body columns have been removed)
|
||||
col_map: dict[int, int] = {}
|
||||
col_idx = 0
|
||||
for bid in body_ids:
|
||||
if bid == ground_body:
|
||||
continue
|
||||
col_map[bid] = col_idx
|
||||
col_idx += 1
|
||||
|
||||
results: list[BodyDofLabel] = []
|
||||
|
||||
if j_reduced.size == 0:
|
||||
# No constraints — every body is fully free
|
||||
for bid in body_ids:
|
||||
if bid == ground_body:
|
||||
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
|
||||
else:
|
||||
results.append(BodyDofLabel(body_id=bid, translational_dof=3, rotational_dof=3))
|
||||
return results
|
||||
|
||||
# Full SVD to get nullspace
|
||||
_u, s, vh = np.linalg.svd(j_reduced, full_matrices=True)
|
||||
rank = int(np.sum(s > _SVD_TOL))
|
||||
n_cols = j_reduced.shape[1]
|
||||
|
||||
if rank >= n_cols:
|
||||
# Fully constrained — no nullspace
|
||||
for bid in body_ids:
|
||||
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
|
||||
return results
|
||||
|
||||
# Nullspace basis: rows of Vh beyond the rank
|
||||
nullspace = vh[rank:] # shape: (n_cols - rank, n_cols)
|
||||
|
||||
for bid in body_ids:
|
||||
if bid == ground_body:
|
||||
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
|
||||
continue
|
||||
|
||||
idx = col_map[bid]
|
||||
trans_cols = nullspace[:, idx * 6 : idx * 6 + 3]
|
||||
rot_cols = nullspace[:, idx * 6 + 3 : idx * 6 + 6]
|
||||
|
||||
# Rank of each block = DOF in that category
|
||||
if trans_cols.size > 0:
|
||||
sv_t = np.linalg.svd(trans_cols, compute_uv=False)
|
||||
t_dof = int(np.sum(sv_t > _SVD_TOL))
|
||||
else:
|
||||
t_dof = 0
|
||||
|
||||
if rot_cols.size > 0:
|
||||
sv_r = np.linalg.svd(rot_cols, compute_uv=False)
|
||||
r_dof = int(np.sum(sv_r > _SVD_TOL))
|
||||
else:
|
||||
r_dof = 0
|
||||
|
||||
results.append(BodyDofLabel(body_id=bid, translational_dof=t_dof, rotational_dof=r_dof))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main labeling function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def label_assembly(
|
||||
bodies: list[RigidBody],
|
||||
joints: list[Joint],
|
||||
ground_body: int | None = None,
|
||||
) -> AssemblyLabels:
|
||||
"""Produce complete ground truth labels for an assembly.
|
||||
|
||||
Runs both the pebble game and Jacobian verification internally,
|
||||
then correlates their results into per-constraint, per-joint,
|
||||
per-body, and assembly-level labels.
|
||||
|
||||
Args:
|
||||
bodies: Rigid bodies in the assembly.
|
||||
joints: Joints connecting the bodies.
|
||||
ground_body: If set, this body is fixed to the world.
|
||||
|
||||
Returns:
|
||||
AssemblyLabels with full label set and embedded ConstraintAnalysis.
|
||||
"""
|
||||
# ---- Pebble Game ----
|
||||
pg = PebbleGame3D()
|
||||
all_edge_results: list[dict[str, Any]] = []
|
||||
|
||||
if ground_body is not None:
|
||||
pg.add_body(_GROUND_ID)
|
||||
|
||||
for body in bodies:
|
||||
pg.add_body(body.body_id)
|
||||
|
||||
if ground_body is not None:
|
||||
ground_joint = Joint(
|
||||
joint_id=-1,
|
||||
body_a=ground_body,
|
||||
body_b=_GROUND_ID,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=bodies[0].position if bodies else np.zeros(3),
|
||||
anchor_b=bodies[0].position if bodies else np.zeros(3),
|
||||
)
|
||||
pg.add_joint(ground_joint)
|
||||
|
||||
for joint in joints:
|
||||
results = pg.add_joint(joint)
|
||||
all_edge_results.extend(results)
|
||||
|
||||
grounded = ground_body is not None
|
||||
combinatorial_independent = len(pg.state.independent_edges)
|
||||
raw_dof = pg.get_dof()
|
||||
ground_offset = 6 if grounded else 0
|
||||
effective_dof = raw_dof - ground_offset
|
||||
effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))
|
||||
|
||||
redundant_count = pg.get_redundant_count()
|
||||
if redundant_count > 0 and effective_internal_dof > 0:
|
||||
classification = "mixed"
|
||||
elif redundant_count > 0:
|
||||
classification = "overconstrained"
|
||||
elif effective_internal_dof > 0:
|
||||
classification = "underconstrained"
|
||||
else:
|
||||
classification = "well-constrained"
|
||||
|
||||
# ---- Jacobian Verification ----
|
||||
verifier = JacobianVerifier(bodies)
|
||||
|
||||
for joint in joints:
|
||||
verifier.add_joint_constraints(joint)
|
||||
|
||||
j_full = verifier.get_jacobian()
|
||||
j_reduced = j_full.copy()
|
||||
if ground_body is not None and j_reduced.size > 0:
|
||||
idx = verifier.body_index[ground_body]
|
||||
cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
|
||||
j_reduced = np.delete(j_reduced, cols_to_remove, axis=1)
|
||||
|
||||
if j_reduced.size > 0:
|
||||
sv = np.linalg.svd(j_reduced, compute_uv=False)
|
||||
jacobian_rank = int(np.sum(sv > _SVD_TOL))
|
||||
else:
|
||||
jacobian_rank = 0
|
||||
|
||||
n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies)
|
||||
jacobian_nullity = n_cols - jacobian_rank
|
||||
dependent_rows = verifier.find_dependencies()
|
||||
dependent_set = set(dependent_rows)
|
||||
|
||||
trivial_dof = 0 if grounded else 6
|
||||
jacobian_internal_dof = jacobian_nullity - trivial_dof
|
||||
geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
|
||||
is_rigid = jacobian_nullity <= trivial_dof
|
||||
is_minimally_rigid = is_rigid and len(dependent_rows) == 0
|
||||
|
||||
# ---- Per-constraint labels ----
|
||||
# Map Jacobian rows to (joint_id, constraint_index).
|
||||
# Rows are added contiguously per joint in the same order as joints.
|
||||
row_to_joint: list[tuple[int, int]] = []
|
||||
for joint in joints:
|
||||
dof = joint.joint_type.dof
|
||||
for ci in range(dof):
|
||||
row_to_joint.append((joint.joint_id, ci))
|
||||
|
||||
per_constraint: list[ConstraintLabel] = []
|
||||
for edge_idx, edge_result in enumerate(all_edge_results):
|
||||
jid = edge_result["joint_id"]
|
||||
ci = edge_result["constraint_index"]
|
||||
pebble_indep = edge_result["independent"]
|
||||
|
||||
# Find matching Jacobian row
|
||||
jacobian_indep = True
|
||||
if edge_idx < len(row_to_joint):
|
||||
row_idx = edge_idx
|
||||
jacobian_indep = row_idx not in dependent_set
|
||||
|
||||
per_constraint.append(
|
||||
ConstraintLabel(
|
||||
joint_id=jid,
|
||||
constraint_idx=ci,
|
||||
pebble_independent=pebble_indep,
|
||||
jacobian_independent=jacobian_indep,
|
||||
)
|
||||
)
|
||||
|
||||
# ---- Per-joint labels ----
|
||||
joint_agg: dict[int, JointLabel] = {}
|
||||
for cl in per_constraint:
|
||||
if cl.joint_id not in joint_agg:
|
||||
joint_agg[cl.joint_id] = JointLabel(
|
||||
joint_id=cl.joint_id,
|
||||
independent_count=0,
|
||||
redundant_count=0,
|
||||
total=0,
|
||||
)
|
||||
jl = joint_agg[cl.joint_id]
|
||||
jl.total += 1
|
||||
if cl.pebble_independent:
|
||||
jl.independent_count += 1
|
||||
else:
|
||||
jl.redundant_count += 1
|
||||
|
||||
per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg]
|
||||
|
||||
# ---- Per-body DOF labels ----
|
||||
body_ids = [b.body_id for b in bodies]
|
||||
per_body = _compute_per_body_dof(
|
||||
j_reduced,
|
||||
body_ids,
|
||||
ground_body,
|
||||
verifier.body_index,
|
||||
)
|
||||
|
||||
# ---- Assembly label ----
|
||||
assembly_label = AssemblyLabel(
|
||||
classification=classification,
|
||||
total_dof=max(0, jacobian_internal_dof),
|
||||
redundant_count=redundant_count,
|
||||
is_rigid=is_rigid,
|
||||
is_minimally_rigid=is_minimally_rigid,
|
||||
has_degeneracy=geometric_degeneracies > 0,
|
||||
)
|
||||
|
||||
# ---- ConstraintAnalysis (for backward compat) ----
|
||||
analysis = ConstraintAnalysis(
|
||||
combinatorial_dof=effective_dof,
|
||||
combinatorial_internal_dof=effective_internal_dof,
|
||||
combinatorial_redundant=redundant_count,
|
||||
combinatorial_classification=classification,
|
||||
per_edge_results=all_edge_results,
|
||||
jacobian_rank=jacobian_rank,
|
||||
jacobian_nullity=jacobian_nullity,
|
||||
jacobian_internal_dof=max(0, jacobian_internal_dof),
|
||||
numerically_dependent=dependent_rows,
|
||||
geometric_degeneracies=geometric_degeneracies,
|
||||
is_rigid=is_rigid,
|
||||
is_minimally_rigid=is_minimally_rigid,
|
||||
)
|
||||
|
||||
return AssemblyLabels(
|
||||
per_constraint=per_constraint,
|
||||
per_joint=per_joint,
|
||||
per_body=per_body,
|
||||
assembly=assembly_label,
|
||||
analysis=analysis,
|
||||
)
|
||||
258
solver/datagen/pebble_game.py
Normal file
258
solver/datagen/pebble_game.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis.
|
||||
|
||||
Implements the pebble game algorithm adapted for CAD assembly constraint
|
||||
graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints
|
||||
between bodies remove DOF according to their type.
|
||||
|
||||
The pebble game provides a fast combinatorial *necessary* condition for
|
||||
rigidity via Tay's theorem. It does not detect geometric degeneracies —
|
||||
use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient*
|
||||
condition.
|
||||
|
||||
References:
|
||||
- Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008
|
||||
- Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity
|
||||
Percolation: The Pebble Game", J. Comput. Phys., 1997
|
||||
- Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.types import Joint, PebbleState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["PebbleGame3D"]
|
||||
|
||||
|
||||
class PebbleGame3D:
|
||||
"""Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks.
|
||||
|
||||
For body-bar-hinge structures in 3D, Tay's theorem states that a
|
||||
multigraph G on n vertices is generically minimally rigid iff:
|
||||
|E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2)
|
||||
|
||||
The (6,6)-pebble game tests this sparsity condition incrementally.
|
||||
Each vertex starts with 6 pebbles (representing 6 DOF). To insert
|
||||
an edge, we need to collect 6+1=7 pebbles on its two endpoints.
|
||||
If we can, the edge is independent (removes a DOF). If not, it's
|
||||
redundant (overconstrained).
|
||||
|
||||
In the CAD assembly context:
|
||||
- Vertices = rigid bodies
|
||||
- Edges = scalar constraints from joints
|
||||
- A revolute joint (5 DOF removed) maps to 5 multigraph edges
|
||||
- A fixed joint (6 DOF removed) maps to 6 multigraph edges
|
||||
"""
|
||||
|
||||
K = 6 # Pebbles per vertex (DOF per rigid body in 3D)
|
||||
L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.state = PebbleState()
|
||||
self._edge_counter = 0
|
||||
self._bodies: set[int] = set()
|
||||
|
||||
def add_body(self, body_id: int) -> None:
|
||||
"""Register a rigid body (vertex) with K=6 free pebbles."""
|
||||
if body_id in self._bodies:
|
||||
return
|
||||
self._bodies.add(body_id)
|
||||
self.state.free_pebbles[body_id] = self.K
|
||||
self.state.incoming[body_id] = set()
|
||||
self.state.outgoing[body_id] = set()
|
||||
|
||||
def add_joint(self, joint: Joint) -> list[dict[str, Any]]:
|
||||
"""Expand a joint into multigraph edges and test each for independence.
|
||||
|
||||
A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph.
|
||||
Each edge is tested individually via the pebble game.
|
||||
|
||||
Returns a list of dicts, one per scalar constraint, with:
|
||||
- edge_id: int
|
||||
- independent: bool
|
||||
- dof_remaining: int (total free pebbles after this edge)
|
||||
"""
|
||||
self.add_body(joint.body_a)
|
||||
self.add_body(joint.body_b)
|
||||
|
||||
num_constraints = joint.joint_type.dof
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(num_constraints):
|
||||
edge_id = self._edge_counter
|
||||
self._edge_counter += 1
|
||||
|
||||
independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b)
|
||||
total_free = sum(self.state.free_pebbles.values())
|
||||
|
||||
results.append(
|
||||
{
|
||||
"edge_id": edge_id,
|
||||
"joint_id": joint.joint_id,
|
||||
"constraint_index": i,
|
||||
"independent": independent,
|
||||
"dof_remaining": total_free,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool:
|
||||
"""Try to insert a directed edge between u and v.
|
||||
|
||||
The edge is accepted (independent) iff we can collect L+1 = 7
|
||||
pebbles on the two endpoints {u, v} combined.
|
||||
|
||||
If accepted, one pebble is consumed and the edge is directed
|
||||
away from the vertex that gives up the pebble.
|
||||
"""
|
||||
# Count current free pebbles on u and v
|
||||
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
|
||||
|
||||
# Try to gather enough pebbles via DFS reachability search
|
||||
if available < self.L + 1:
|
||||
needed = (self.L + 1) - available
|
||||
# Try to free pebbles by searching from u first, then v
|
||||
for target in (u, v):
|
||||
while needed > 0:
|
||||
found = self._search_and_collect(target, frozenset({u, v}))
|
||||
if not found:
|
||||
break
|
||||
needed -= 1
|
||||
|
||||
# Recheck after collection attempts
|
||||
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
|
||||
|
||||
if available >= self.L + 1:
|
||||
# Accept: consume a pebble from whichever endpoint has one
|
||||
source = u if self.state.free_pebbles[u] > 0 else v
|
||||
|
||||
self.state.free_pebbles[source] -= 1
|
||||
self.state.directed_edges[edge_id] = (source, v if source == u else u)
|
||||
self.state.outgoing[source].add((edge_id, v if source == u else u))
|
||||
target = v if source == u else u
|
||||
self.state.incoming[target].add((edge_id, source))
|
||||
self.state.independent_edges.add(edge_id)
|
||||
return True
|
||||
else:
|
||||
# Reject: edge is redundant (overconstrained)
|
||||
self.state.redundant_edges.add(edge_id)
|
||||
return False
|
||||
|
||||
def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
|
||||
"""DFS to find a free pebble reachable from *target* and move it.
|
||||
|
||||
Follows directed edges *backwards* (from destination to source)
|
||||
to find a vertex with a free pebble that isn't in *forbidden*.
|
||||
When found, reverses the path to move the pebble to *target*.
|
||||
|
||||
Returns True if a pebble was successfully moved to target.
|
||||
"""
|
||||
# BFS/DFS through the directed graph following outgoing edges
|
||||
# from target. An outgoing edge (target -> w) means target spent
|
||||
# a pebble on that edge. If we can find a vertex with a free
|
||||
# pebble, we reverse edges along the path to move it.
|
||||
|
||||
visited: set[int] = set()
|
||||
# Stack: (current_vertex, path_of_edge_ids_to_reverse)
|
||||
stack: list[tuple[int, list[int]]] = [(target, [])]
|
||||
|
||||
while stack:
|
||||
current, path = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
|
||||
# Check if current vertex (not in forbidden, not target)
|
||||
# has a free pebble
|
||||
if (
|
||||
current != target
|
||||
and current not in forbidden
|
||||
and self.state.free_pebbles[current] > 0
|
||||
):
|
||||
# Found a pebble — reverse the path
|
||||
self._reverse_path(path, current)
|
||||
return True
|
||||
|
||||
# Follow outgoing edges from current vertex
|
||||
for eid, neighbor in self.state.outgoing.get(current, set()):
|
||||
if neighbor not in visited:
|
||||
stack.append((neighbor, [*path, eid]))
|
||||
|
||||
return False
|
||||
|
||||
def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
|
||||
"""Reverse directed edges along a path, moving a pebble to the start.
|
||||
|
||||
The pebble at *pebble_source* is consumed by the last edge in
|
||||
the path, and a pebble is freed at the path's start vertex.
|
||||
"""
|
||||
if not edge_ids:
|
||||
return
|
||||
|
||||
# Reverse each edge in the path
|
||||
for eid in edge_ids:
|
||||
old_source, old_target = self.state.directed_edges[eid]
|
||||
|
||||
# Remove from adjacency
|
||||
self.state.outgoing[old_source].discard((eid, old_target))
|
||||
self.state.incoming[old_target].discard((eid, old_source))
|
||||
|
||||
# Reverse direction
|
||||
self.state.directed_edges[eid] = (old_target, old_source)
|
||||
self.state.outgoing[old_target].add((eid, old_source))
|
||||
self.state.incoming[old_source].add((eid, old_target))
|
||||
|
||||
# Move pebble counts: source loses one, first vertex in path gains one
|
||||
self.state.free_pebbles[pebble_source] -= 1
|
||||
|
||||
# After all reversals, the vertex at the beginning of the
|
||||
# search path gains a pebble
|
||||
_first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
|
||||
self.state.free_pebbles[first_tgt] += 1
|
||||
|
||||
def get_dof(self) -> int:
|
||||
"""Total remaining DOF = sum of free pebbles.
|
||||
|
||||
For a fully rigid assembly, this should be 6 (the trivial rigid
|
||||
body motions of the whole assembly). Internal DOF = total - 6.
|
||||
"""
|
||||
return sum(self.state.free_pebbles.values())
|
||||
|
||||
def get_internal_dof(self) -> int:
|
||||
"""Internal (non-trivial) degrees of freedom."""
|
||||
return max(0, self.get_dof() - 6)
|
||||
|
||||
def is_rigid(self) -> bool:
|
||||
"""Combinatorial rigidity check: rigid iff at most 6 pebbles remain."""
|
||||
return self.get_dof() <= self.L
|
||||
|
||||
def get_redundant_count(self) -> int:
|
||||
"""Number of redundant (overconstrained) scalar constraints."""
|
||||
return len(self.state.redundant_edges)
|
||||
|
||||
def classify_assembly(self, *, grounded: bool = False) -> str:
|
||||
"""Classify the assembly state.
|
||||
|
||||
Args:
|
||||
grounded: If True, the baseline trivial DOF is 0 (not 6),
|
||||
because the ground body's 6 DOF were removed.
|
||||
"""
|
||||
total_dof = self.get_dof()
|
||||
redundant = self.get_redundant_count()
|
||||
baseline = 0 if grounded else self.L
|
||||
|
||||
if redundant > 0 and total_dof > baseline:
|
||||
return "mixed" # Both under and over-constrained regions
|
||||
elif redundant > 0:
|
||||
return "overconstrained"
|
||||
elif total_dof > baseline:
|
||||
return "underconstrained"
|
||||
elif total_dof == baseline:
|
||||
return "well-constrained"
|
||||
else:
|
||||
return "overconstrained"
|
||||
144
solver/datagen/types.py
Normal file
144
solver/datagen/types.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Shared data types for assembly constraint analysis.
|
||||
|
||||
Types ported from the pebble-game synthetic data generator for reuse
|
||||
across the solver package (data generation, training, inference).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"ConstraintAnalysis",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Joint definitions: each joint type removes a known number of DOF
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JointType(enum.Enum):
|
||||
"""Standard CAD joint types with their DOF-removal counts.
|
||||
|
||||
Each joint between two 6-DOF rigid bodies removes a specific number
|
||||
of relative degrees of freedom. In the body-bar-hinge multigraph
|
||||
representation, each joint maps to a number of edges equal to the
|
||||
DOF it removes.
|
||||
|
||||
Values are ``(ordinal, dof_removed)`` tuples so that joint types
|
||||
sharing the same DOF count remain distinct enum members. Use the
|
||||
:attr:`dof` property to get the scalar constraint count.
|
||||
"""
|
||||
|
||||
FIXED = (0, 6) # Locks all relative motion
|
||||
REVOLUTE = (1, 5) # Allows rotation about one axis
|
||||
CYLINDRICAL = (2, 4) # Allows rotation + translation along one axis
|
||||
SLIDER = (3, 5) # Allows translation along one axis (prismatic)
|
||||
BALL = (4, 3) # Allows rotation about a point (spherical)
|
||||
PLANAR = (5, 3) # Allows 2D translation + rotation normal to plane
|
||||
SCREW = (6, 5) # Coupled rotation-translation (helical)
|
||||
UNIVERSAL = (7, 4) # Two rotational DOF (Cardan/U-joint)
|
||||
PARALLEL = (8, 3) # Forces parallel orientation (3 rotation constraints)
|
||||
PERPENDICULAR = (9, 1) # Single angular constraint
|
||||
DISTANCE = (10, 1) # Single scalar distance constraint
|
||||
|
||||
@property
|
||||
def dof(self) -> int:
|
||||
"""Number of scalar constraints (DOF removed) by this joint type."""
|
||||
return self.value[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class RigidBody:
|
||||
"""A rigid body in the assembly with pose and geometry info."""
|
||||
|
||||
body_id: int
|
||||
position: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
orientation: np.ndarray = field(default_factory=lambda: np.eye(3))
|
||||
|
||||
# Anchor points for joints, in local frame
|
||||
# Populated when joints reference specific geometry
|
||||
local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Joint:
|
||||
"""A joint connecting two rigid bodies."""
|
||||
|
||||
joint_id: int
|
||||
body_a: int # Index of first body
|
||||
body_b: int # Index of second body
|
||||
joint_type: JointType
|
||||
|
||||
# Joint parameters in world frame
|
||||
anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
|
||||
axis: np.ndarray = field(
|
||||
default_factory=lambda: np.array([0.0, 0.0, 1.0]),
|
||||
)
|
||||
|
||||
# For screw joints
|
||||
pitch: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class PebbleState:
|
||||
"""Tracks the state of the pebble game on the multigraph."""
|
||||
|
||||
# Number of free pebbles per body (vertex). Starts at 6.
|
||||
free_pebbles: dict[int, int] = field(default_factory=dict)
|
||||
|
||||
# Directed edges: edge_id -> (source_body, target_body)
|
||||
# Edge is directed away from the body that "spent" a pebble.
|
||||
directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)
|
||||
|
||||
# Track which edges are independent vs redundant
|
||||
independent_edges: set[int] = field(default_factory=set)
|
||||
redundant_edges: set[int] = field(default_factory=set)
|
||||
|
||||
# Adjacency: body_id -> set of (edge_id, neighbor_body_id)
|
||||
# Following directed edges *towards* a body (incoming edges)
|
||||
incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||
|
||||
# Outgoing edges from a body
|
||||
outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConstraintAnalysis:
|
||||
"""Results of analyzing an assembly's constraint system."""
|
||||
|
||||
# Pebble game (combinatorial) results
|
||||
combinatorial_dof: int
|
||||
combinatorial_internal_dof: int
|
||||
combinatorial_redundant: int
|
||||
combinatorial_classification: str
|
||||
per_edge_results: list[dict[str, Any]]
|
||||
|
||||
# Numerical (Jacobian) results
|
||||
jacobian_rank: int
|
||||
jacobian_nullity: int # = 6n - rank = total DOF
|
||||
jacobian_internal_dof: int # = nullity - 6
|
||||
numerically_dependent: list[int]
|
||||
|
||||
# Combined
|
||||
geometric_degeneracies: int # = combinatorial_independent - jacobian_rank
|
||||
is_rigid: bool
|
||||
is_minimally_rigid: bool
|
||||
0
solver/datasets/__init__.py
Normal file
0
solver/datasets/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
0
solver/models/__init__.py
Normal file
0
solver/models/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for solver.datagen.analysis -- combined analysis function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _two_bodies() -> list[RigidBody]:
|
||||
return [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
]
|
||||
|
||||
|
||||
def _triangle_bodies() -> list[RigidBody]:
|
||||
return [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
RigidBody(2, position=np.array([1.0, 1.7, 0.0])),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTwoBodiesRevolute:
|
||||
"""Demo scenario 1: two bodies connected by a revolute joint."""
|
||||
|
||||
@pytest.fixture()
|
||||
def result(self) -> ConstraintAnalysis:
|
||||
bodies = _two_bodies()
|
||||
joints = [
|
||||
Joint(
|
||||
0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
]
|
||||
return analyze_assembly(bodies, joints, ground_body=0)
|
||||
|
||||
def test_internal_dof(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.jacobian_internal_dof == 1
|
||||
|
||||
def test_not_rigid(self, result: ConstraintAnalysis) -> None:
|
||||
assert not result.is_rigid
|
||||
|
||||
def test_classification(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_no_redundant(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.combinatorial_redundant == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTwoBodiesFixed:
|
||||
"""Demo scenario 2: two bodies connected by a fixed joint."""
|
||||
|
||||
@pytest.fixture()
|
||||
def result(self) -> ConstraintAnalysis:
|
||||
bodies = _two_bodies()
|
||||
joints = [
|
||||
Joint(
|
||||
0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
),
|
||||
]
|
||||
return analyze_assembly(bodies, joints, ground_body=0)
|
||||
|
||||
def test_internal_dof(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.jacobian_internal_dof == 0
|
||||
|
||||
def test_rigid(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.is_rigid
|
||||
|
||||
def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.is_minimally_rigid
|
||||
|
||||
def test_classification(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.combinatorial_classification == "well-constrained"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3: Triangle with revolute joints (overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriangleRevolute:
|
||||
"""Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""
|
||||
|
||||
@pytest.fixture()
|
||||
def result(self) -> ConstraintAnalysis:
|
||||
bodies = _triangle_bodies()
|
||||
joints = [
|
||||
Joint(
|
||||
0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
Joint(
|
||||
1,
|
||||
body_a=1,
|
||||
body_b=2,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.5, 0.85, 0.0]),
|
||||
anchor_b=np.array([1.5, 0.85, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
Joint(
|
||||
2,
|
||||
body_a=2,
|
||||
body_b=0,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([0.5, 0.85, 0.0]),
|
||||
anchor_b=np.array([0.5, 0.85, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
]
|
||||
return analyze_assembly(bodies, joints, ground_body=0)
|
||||
|
||||
def test_has_redundant(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.combinatorial_redundant > 0
|
||||
|
||||
def test_classification(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.combinatorial_classification in ("overconstrained", "mixed")
|
||||
|
||||
def test_rigid(self, result: ConstraintAnalysis) -> None:
|
||||
assert result.is_rigid
|
||||
|
||||
def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
|
||||
assert len(result.numerically_dependent) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 4: Parallel revolute axes (geometric degeneracy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelRevoluteAxes:
|
||||
"""Demo scenario 4: parallel revolute axes create geometric degeneracies."""
|
||||
|
||||
@pytest.fixture()
|
||||
def result(self) -> ConstraintAnalysis:
|
||||
bodies = [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
|
||||
]
|
||||
joints = [
|
||||
Joint(
|
||||
0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
Joint(
|
||||
1,
|
||||
body_a=1,
|
||||
body_b=2,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([3.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([3.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
),
|
||||
]
|
||||
return analyze_assembly(bodies, joints, ground_body=0)
|
||||
|
||||
def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None:
|
||||
"""Parallel axes produce at least one geometric degeneracy."""
|
||||
assert result.geometric_degeneracies > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoJoints:
|
||||
"""Assembly with bodies but no joints."""
|
||||
|
||||
def test_all_dof_free(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
result = analyze_assembly(bodies, [], ground_body=0)
|
||||
# Body 1 is completely free (6 DOF), body 0 is grounded
|
||||
assert result.jacobian_internal_dof > 0
|
||||
assert not result.is_rigid
|
||||
|
||||
def test_ungrounded(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
result = analyze_assembly(bodies, [])
|
||||
assert result.combinatorial_classification == "underconstrained"
|
||||
|
||||
|
||||
class TestReturnType:
|
||||
"""Verify the return object is a proper ConstraintAnalysis."""
|
||||
|
||||
def test_instance(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
joints = [Joint(0, 0, 1, JointType.FIXED)]
|
||||
result = analyze_assembly(bodies, joints)
|
||||
assert isinstance(result, ConstraintAnalysis)
|
||||
|
||||
def test_per_edge_results_populated(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
joints = [Joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
result = analyze_assembly(bodies, joints)
|
||||
assert len(result.per_edge_results) == 5
|
||||
337
tests/datagen/test_dataset.py
Normal file
337
tests/datagen/test_dataset.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
_derive_shard_seed,
|
||||
_parse_scalar,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetConfig:
|
||||
"""DatasetConfig construction and defaults."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
cfg = DatasetConfig()
|
||||
assert cfg.num_assemblies == 100_000
|
||||
assert cfg.seed == 42
|
||||
assert cfg.shard_size == 1000
|
||||
assert cfg.num_workers == 4
|
||||
|
||||
def test_from_dict_flat(self) -> None:
|
||||
d: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
|
||||
cfg = DatasetConfig.from_dict(d)
|
||||
assert cfg.num_assemblies == 500
|
||||
assert cfg.seed == 123
|
||||
|
||||
def test_from_dict_nested_body_count(self) -> None:
|
||||
d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
|
||||
cfg = DatasetConfig.from_dict(d)
|
||||
assert cfg.body_count_min == 3
|
||||
assert cfg.body_count_max == 20
|
||||
|
||||
def test_from_dict_flat_body_count(self) -> None:
|
||||
d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
|
||||
cfg = DatasetConfig.from_dict(d)
|
||||
assert cfg.body_count_min == 5
|
||||
assert cfg.body_count_max == 30
|
||||
|
||||
def test_from_dict_complexity_distribution(self) -> None:
|
||||
d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}}
|
||||
cfg = DatasetConfig.from_dict(d)
|
||||
assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4}
|
||||
|
||||
def test_from_dict_templates(self) -> None:
|
||||
d: dict[str, Any] = {"templates": ["chain", "tree"]}
|
||||
cfg = DatasetConfig.from_dict(d)
|
||||
assert cfg.templates == ["chain", "tree"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseScalar:
|
||||
"""_parse_scalar handles different value types."""
|
||||
|
||||
def test_int(self) -> None:
|
||||
assert _parse_scalar("42") == 42
|
||||
|
||||
def test_float(self) -> None:
|
||||
assert _parse_scalar("3.14") == 3.14
|
||||
|
||||
def test_bool_true(self) -> None:
|
||||
assert _parse_scalar("true") is True
|
||||
|
||||
def test_bool_false(self) -> None:
|
||||
assert _parse_scalar("false") is False
|
||||
|
||||
def test_string(self) -> None:
|
||||
assert _parse_scalar("hello") == "hello"
|
||||
|
||||
def test_inline_comment(self) -> None:
|
||||
assert _parse_scalar("0.4 # some comment") == 0.4
|
||||
|
||||
|
||||
class TestParseSimpleYaml:
|
||||
"""parse_simple_yaml handles the synthetic.yaml format."""
|
||||
|
||||
def test_flat_scalars(self, tmp_path: Path) -> None:
|
||||
yaml_file = tmp_path / "test.yaml"
|
||||
yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
|
||||
result = parse_simple_yaml(str(yaml_file))
|
||||
assert result["name"] == "test"
|
||||
assert result["num"] == 42
|
||||
assert result["ratio"] == 0.5
|
||||
|
||||
def test_nested_dict(self, tmp_path: Path) -> None:
|
||||
yaml_file = tmp_path / "test.yaml"
|
||||
yaml_file.write_text("body_count:\n min: 2\n max: 50\n")
|
||||
result = parse_simple_yaml(str(yaml_file))
|
||||
assert result["body_count"] == {"min": 2, "max": 50}
|
||||
|
||||
def test_list(self, tmp_path: Path) -> None:
|
||||
yaml_file = tmp_path / "test.yaml"
|
||||
yaml_file.write_text("templates:\n - chain\n - tree\n - loop\n")
|
||||
result = parse_simple_yaml(str(yaml_file))
|
||||
assert result["templates"] == ["chain", "tree", "loop"]
|
||||
|
||||
def test_inline_comments(self, tmp_path: Path) -> None:
|
||||
yaml_file = tmp_path / "test.yaml"
|
||||
yaml_file.write_text("dist:\n simple: 0.4 # comment\n")
|
||||
result = parse_simple_yaml(str(yaml_file))
|
||||
assert result["dist"]["simple"] == 0.4
|
||||
|
||||
def test_synthetic_yaml(self) -> None:
|
||||
"""Parse the actual project config."""
|
||||
result = parse_simple_yaml("configs/dataset/synthetic.yaml")
|
||||
assert result["name"] == "synthetic"
|
||||
assert result["num_assemblies"] == 100000
|
||||
assert isinstance(result["complexity_distribution"], dict)
|
||||
assert isinstance(result["templates"], list)
|
||||
assert result["shard_size"] == 1000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShardSeedDerivation:
|
||||
"""_derive_shard_seed is deterministic and unique per shard."""
|
||||
|
||||
def test_deterministic(self) -> None:
|
||||
s1 = _derive_shard_seed(42, 0)
|
||||
s2 = _derive_shard_seed(42, 0)
|
||||
assert s1 == s2
|
||||
|
||||
def test_different_shards(self) -> None:
|
||||
s1 = _derive_shard_seed(42, 0)
|
||||
s2 = _derive_shard_seed(42, 1)
|
||||
assert s1 != s2
|
||||
|
||||
def test_different_global_seeds(self) -> None:
|
||||
s1 = _derive_shard_seed(42, 0)
|
||||
s2 = _derive_shard_seed(99, 0)
|
||||
assert s1 != s2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetGenerator — small end-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetGenerator:
|
||||
"""End-to-end tests with small datasets."""
|
||||
|
||||
def test_small_generation(self, tmp_path: Path) -> None:
|
||||
"""Generate 10 examples in a single shard."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=10,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=10,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
shards_dir = tmp_path / "output" / "shards"
|
||||
assert shards_dir.exists()
|
||||
shard_files = sorted(shards_dir.glob("shard_*"))
|
||||
assert len(shard_files) == 1
|
||||
|
||||
index_file = tmp_path / "output" / "index.json"
|
||||
assert index_file.exists()
|
||||
index = json.loads(index_file.read_text())
|
||||
assert index["total_assemblies"] == 10
|
||||
|
||||
stats_file = tmp_path / "output" / "stats.json"
|
||||
assert stats_file.exists()
|
||||
|
||||
def test_multi_shard(self, tmp_path: Path) -> None:
|
||||
"""Generate 20 examples across 2 shards."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=20,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=10,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
shards_dir = tmp_path / "output" / "shards"
|
||||
shard_files = sorted(shards_dir.glob("shard_*"))
|
||||
assert len(shard_files) == 2
|
||||
|
||||
def test_resume_skips_completed(self, tmp_path: Path) -> None:
|
||||
"""Resume skips already-completed shards."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=20,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=10,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
# Record shard modification times
|
||||
shards_dir = tmp_path / "output" / "shards"
|
||||
mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}
|
||||
|
||||
# Remove stats (simulate incomplete) and re-run
|
||||
(tmp_path / "output" / "stats.json").unlink()
|
||||
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
# Shards should NOT have been regenerated
|
||||
for p in shards_dir.glob("shard_*"):
|
||||
assert p.stat().st_mtime == mtimes[p.name]
|
||||
|
||||
# Stats should be regenerated
|
||||
assert (tmp_path / "output" / "stats.json").exists()
|
||||
|
||||
def test_checkpoint_removed(self, tmp_path: Path) -> None:
|
||||
"""Checkpoint file is cleaned up after completion."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=5,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=5,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
checkpoint = tmp_path / "output" / ".checkpoint.json"
|
||||
assert not checkpoint.exists()
|
||||
|
||||
def test_stats_structure(self, tmp_path: Path) -> None:
|
||||
"""stats.json has expected top-level keys."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=10,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=10,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
stats = json.loads((tmp_path / "output" / "stats.json").read_text())
|
||||
assert stats["total_examples"] == 10
|
||||
assert "classification_distribution" in stats
|
||||
assert "body_count_histogram" in stats
|
||||
assert "joint_type_distribution" in stats
|
||||
assert "dof_statistics" in stats
|
||||
assert "geometric_degeneracy" in stats
|
||||
assert "rigidity" in stats
|
||||
|
||||
def test_index_structure(self, tmp_path: Path) -> None:
|
||||
"""index.json has expected format."""
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=15,
|
||||
output_dir=str(tmp_path / "output"),
|
||||
shard_size=10,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
index = json.loads((tmp_path / "output" / "index.json").read_text())
|
||||
assert index["format_version"] == 1
|
||||
assert index["total_assemblies"] == 15
|
||||
assert index["total_shards"] == 2
|
||||
assert "shards" in index
|
||||
for _name, info in index["shards"].items():
|
||||
assert "start_id" in info
|
||||
assert "count" in info
|
||||
|
||||
def test_deterministic_output(self, tmp_path: Path) -> None:
|
||||
"""Same seed produces same results."""
|
||||
for run_dir in ("run1", "run2"):
|
||||
cfg = DatasetConfig(
|
||||
num_assemblies=5,
|
||||
output_dir=str(tmp_path / run_dir),
|
||||
shard_size=5,
|
||||
seed=42,
|
||||
num_workers=1,
|
||||
)
|
||||
DatasetGenerator(cfg).run()
|
||||
|
||||
s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
|
||||
s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
|
||||
assert s1["total_examples"] == s2["total_examples"]
|
||||
assert s1["classification_distribution"] == s2["classification_distribution"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI integration test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCLI:
|
||||
"""Run the script via subprocess."""
|
||||
|
||||
def test_argparse_mode(self, tmp_path: Path) -> None:
|
||||
result = subprocess.run(
|
||||
[
|
||||
sys.executable,
|
||||
"scripts/generate_synthetic.py",
|
||||
"--num-assemblies",
|
||||
"5",
|
||||
"--output-dir",
|
||||
str(tmp_path / "cli_out"),
|
||||
"--shard-size",
|
||||
"5",
|
||||
"--num-workers",
|
||||
"1",
|
||||
"--seed",
|
||||
"42",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd="/home/developer",
|
||||
timeout=120,
|
||||
env={**os.environ, "PYTHONPATH": "/home/developer"},
|
||||
)
|
||||
assert result.returncode == 0, (
|
||||
f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
|
||||
)
|
||||
assert (tmp_path / "cli_out" / "index.json").exists()
|
||||
assert (tmp_path / "cli_out" / "stats.json").exists()
|
||||
682
tests/datagen/test_generator.py
Normal file
682
tests/datagen/test_generator.py
Normal file
@@ -0,0 +1,682 @@
|
||||
"""Tests for solver.datagen.generator -- synthetic assembly generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChainAssembly:
|
||||
"""generate_chain_assembly produces valid underconstrained chains."""
|
||||
|
||||
def test_returns_three_tuple(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
bodies, joints, _analysis = gen.generate_chain_assembly(4)
|
||||
assert len(bodies) == 4
|
||||
assert len(joints) == 3
|
||||
|
||||
def test_chain_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
_, _, analysis = gen.generate_chain_assembly(4)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_chain_body_ids(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
bodies, _, _ = gen.generate_chain_assembly(5)
|
||||
ids = [b.body_id for b in bodies]
|
||||
assert ids == [0, 1, 2, 3, 4]
|
||||
|
||||
def test_chain_joint_connectivity(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
_, joints, _ = gen.generate_chain_assembly(4)
|
||||
for i, j in enumerate(joints):
|
||||
assert j.body_a == i
|
||||
assert j.body_b == i + 1
|
||||
|
||||
def test_chain_custom_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
3,
|
||||
joint_type=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
|
||||
class TestRigidAssembly:
|
||||
"""generate_rigid_assembly produces rigid assemblies."""
|
||||
|
||||
def test_rigid(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_rigid_assembly(4)
|
||||
assert analysis.is_rigid
|
||||
|
||||
def test_spanning_tree_structure(self) -> None:
|
||||
"""n bodies should have at least n-1 joints (spanning tree)."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_rigid_assembly(5)
|
||||
assert len(joints) >= len(bodies) - 1
|
||||
|
||||
def test_deterministic(self) -> None:
|
||||
"""Same seed produces same results."""
|
||||
g1 = SyntheticAssemblyGenerator(seed=99)
|
||||
g2 = SyntheticAssemblyGenerator(seed=99)
|
||||
_, j1, a1 = g1.generate_rigid_assembly(4)
|
||||
_, j2, a2 = g2.generate_rigid_assembly(4)
|
||||
assert a1.jacobian_rank == a2.jacobian_rank
|
||||
assert len(j1) == len(j2)
|
||||
|
||||
|
||||
class TestOverconstrainedAssembly:
|
||||
"""generate_overconstrained_assembly adds redundant constraints."""
|
||||
|
||||
def test_has_redundant(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=2,
|
||||
)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_extra_joints_added(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints_base, _ = gen.generate_rigid_assembly(4)
|
||||
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints_over, _ = gen2.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=3,
|
||||
)
|
||||
# Overconstrained has base joints + extra
|
||||
assert len(joints_over) > len(joints_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTreeAssembly:
|
||||
"""generate_tree_assembly produces tree-structured assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_tree_assembly(6)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_branching_factor(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(
|
||||
10,
|
||||
branching_factor=2,
|
||||
)
|
||||
assert len(bodies) == 10
|
||||
assert len(joints) == 9
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_tree_assembly(10, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
# With 9 joints and 3 types, very likely to use at least 2
|
||||
assert len(used) >= 2
|
||||
|
||||
def test_single_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_tree_assembly(
|
||||
5,
|
||||
joint_types=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
def test_sequential_body_ids(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(7)
|
||||
assert [b.body_id for b in bodies] == list(range(7))
|
||||
|
||||
|
||||
class TestLoopAssembly:
|
||||
"""generate_loop_assembly produces closed-loop assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_loop_assembly(5)
|
||||
assert len(bodies) == 5
|
||||
assert len(joints) == 5 # n joints for n bodies
|
||||
|
||||
def test_has_redundancy(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_loop_assembly(5)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_wrap_around_connectivity(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(4)
|
||||
edges = {(j.body_a, j.body_b) for j in joints}
|
||||
assert (0, 1) in edges
|
||||
assert (1, 2) in edges
|
||||
assert (2, 3) in edges
|
||||
assert (3, 0) in edges # wrap-around
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 3"):
|
||||
gen.generate_loop_assembly(2)
|
||||
|
||||
def test_mixed_joint_types(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
types = [JointType.REVOLUTE, JointType.FIXED]
|
||||
_, joints, _ = gen.generate_loop_assembly(8, joint_types=types)
|
||||
used = {j.joint_type for j in joints}
|
||||
assert len(used) >= 2
|
||||
|
||||
|
||||
class TestStarAssembly:
|
||||
"""generate_star_assembly produces hub-and-spoke assemblies."""
|
||||
|
||||
def test_body_and_joint_counts(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_star_assembly(6)
|
||||
assert len(bodies) == 6
|
||||
assert len(joints) == 5 # n - 1
|
||||
|
||||
def test_all_joints_connect_to_hub(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(6)
|
||||
for j in joints:
|
||||
assert j.body_a == 0 or j.body_b == 0
|
||||
|
||||
def test_underconstrained(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_star_assembly(5)
|
||||
assert analysis.combinatorial_classification == "underconstrained"
|
||||
|
||||
def test_minimum_bodies_error(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="at least 2"):
|
||||
gen.generate_star_assembly(1)
|
||||
|
||||
def test_hub_at_origin(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(4)
|
||||
np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
|
||||
|
||||
|
||||
class TestMixedAssembly:
|
||||
"""generate_mixed_assembly produces tree+loop hybrid assemblies."""
|
||||
|
||||
def test_more_joints_than_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
8,
|
||||
edge_density=0.3,
|
||||
)
|
||||
assert len(joints) > len(bodies) - 1
|
||||
|
||||
def test_density_zero_is_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=0.0,
|
||||
)
|
||||
assert len(joints) == 4 # spanning tree only
|
||||
|
||||
def test_density_validation(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=1.5)
|
||||
with pytest.raises(ValueError, match="must be in"):
|
||||
gen.generate_mixed_assembly(5, edge_density=-0.1)
|
||||
|
||||
def test_no_duplicate_edges(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5)
|
||||
edges = [frozenset([j.body_a, j.body_b]) for j in joints]
|
||||
assert len(edges) == len(set(edges))
|
||||
|
||||
def test_high_density(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=1.0,
|
||||
)
|
||||
# Fully connected: 5*(5-1)/2 = 10 edges
|
||||
assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAxisStrategy:
|
||||
"""Axis sampling strategies produce valid unit vectors."""
|
||||
|
||||
def test_cardinal_axis_from_six(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
|
||||
expected = {
|
||||
(1, 0, 0),
|
||||
(-1, 0, 0),
|
||||
(0, 1, 0),
|
||||
(0, -1, 0),
|
||||
(0, 0, 1),
|
||||
(0, 0, -1),
|
||||
}
|
||||
assert axes == expected
|
||||
|
||||
def test_random_axis_unit_norm(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
for _ in range(50):
|
||||
axis = gen._sample_axis("random")
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
|
||||
def test_near_parallel_close_to_base(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
base = np.array([0.0, 0.0, 1.0])
|
||||
for _ in range(50):
|
||||
axis = gen._near_parallel_axis(base)
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
assert np.dot(axis, base) > 0.95
|
||||
|
||||
def test_sample_axis_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("cardinal")
|
||||
cardinals = [
|
||||
np.array(v, dtype=float)
|
||||
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
|
||||
]
|
||||
assert any(np.allclose(axis, c) for c in cardinals)
|
||||
|
||||
def test_sample_axis_near_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("near_parallel")
|
||||
z = np.array([0.0, 0.0, 1.0])
|
||||
assert np.dot(axis, z) > 0.95
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: orientations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRandomOrientations:
|
||||
"""Bodies should have non-identity orientations."""
|
||||
|
||||
def test_bodies_have_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(5)
|
||||
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
|
||||
assert non_identity >= 3
|
||||
|
||||
def test_orientations_are_valid_rotations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(6)
|
||||
for b in bodies:
|
||||
r = b.orientation
|
||||
# R^T R == I
|
||||
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
|
||||
# det(R) == 1
|
||||
assert abs(np.linalg.det(r) - 1.0) < 1e-10
|
||||
|
||||
def test_all_generators_set_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
# Chain
|
||||
bodies, _, _ = gen.generate_chain_assembly(3)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Loop
|
||||
bodies, _, _ = gen.generate_loop_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Mixed
|
||||
bodies, _, _ = gen.generate_mixed_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: grounded parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGroundedParameter:
|
||||
"""Grounded parameter controls ground_body in analysis."""
|
||||
|
||||
def test_chain_grounded_default(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(4)
|
||||
assert analysis.combinatorial_dof >= 0
|
||||
|
||||
def test_chain_floating(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(
|
||||
4,
|
||||
grounded=False,
|
||||
)
|
||||
# Floating: 6 trivial DOF not subtracted by ground
|
||||
assert analysis.combinatorial_dof >= 6
|
||||
|
||||
def test_floating_vs_grounded_dof_difference(self) -> None:
|
||||
gen1 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
|
||||
# Floating should have higher DOF due to missing ground constraint
|
||||
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gen_method",
|
||||
[
|
||||
"generate_chain_assembly",
|
||||
"generate_rigid_assembly",
|
||||
"generate_tree_assembly",
|
||||
"generate_loop_assembly",
|
||||
"generate_star_assembly",
|
||||
"generate_mixed_assembly",
|
||||
],
|
||||
)
|
||||
def test_all_generators_accept_grounded(
|
||||
self,
|
||||
gen_method: str,
|
||||
) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
method = getattr(gen, gen_method)
|
||||
n = 4
|
||||
# Should not raise
|
||||
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
|
||||
method(n, grounded=False)
|
||||
else:
|
||||
method(n, grounded=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: parallel axis injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelAxisInjection:
|
||||
"""parallel_axis_prob causes shared axis direction."""
|
||||
|
||||
def test_parallel_axes_similar(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
# Near-parallel: |dot| close to 1
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_zero_prob_no_forced_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=0.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
|
||||
# With 5 random axes, extremely unlikely all are parallel
|
||||
assert min(dots) < 0.95
|
||||
|
||||
def test_parallel_on_loop(self) -> None:
|
||||
"""Parallel axes on a loop assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_parallel_on_star(self) -> None:
|
||||
"""Parallel axes on a star assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComplexityTiers:
|
||||
"""Complexity tier parameter on batch generation."""
|
||||
|
||||
def test_simple_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="simple")
|
||||
lo, hi = COMPLEXITY_RANGES["simple"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_medium_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, complexity_tier="medium")
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_complex_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(3, complexity_tier="complex")
|
||||
lo, hi = COMPLEXITY_RANGES["complex"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
def test_tier_overrides_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
n_bodies_range=(2, 3),
|
||||
complexity_tier="medium",
|
||||
)
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrainingBatch:
|
||||
"""generate_training_batch produces well-structured examples."""
|
||||
|
||||
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||
"example_id",
|
||||
"generator_type",
|
||||
"grounded",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
"body_orientations",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"labels",
|
||||
"assembly_classification",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
"internal_dof",
|
||||
"geometric_degeneracies",
|
||||
}
|
||||
|
||||
VALID_GEN_TYPES: ClassVar[set[str]] = {
|
||||
"chain",
|
||||
"rigid",
|
||||
"overconstrained",
|
||||
"tree",
|
||||
"loop",
|
||||
"star",
|
||||
"mixed",
|
||||
}
|
||||
|
||||
def test_batch_size(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20)
|
||||
assert len(batch) == 20
|
||||
|
||||
def test_example_keys(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert set(ex.keys()) == self.EXPECTED_KEYS
|
||||
|
||||
def test_example_ids_sequential(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(15)
|
||||
assert [ex["example_id"] for ex in batch] == list(range(15))
|
||||
|
||||
def test_generator_type_valid(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(50)
|
||||
for ex in batch:
|
||||
assert ex["generator_type"] in self.VALID_GEN_TYPES
|
||||
|
||||
def test_generator_type_diversity(self) -> None:
|
||||
"""100-sample batch should use at least 5 of 7 generator types."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100)
|
||||
types = {ex["generator_type"] for ex in batch}
|
||||
assert len(types) >= 5
|
||||
|
||||
def test_default_body_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
# default (3, 8), but loop/star may clamp
|
||||
assert 2 <= ex["n_bodies"] <= 7
|
||||
|
||||
def test_joint_label_consistency(self) -> None:
|
||||
"""independent + redundant == total for every joint."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
for label in ex["joint_labels"].values():
|
||||
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||
assert total == label["total_constraints"]
|
||||
|
||||
def test_body_orientations_present(self) -> None:
|
||||
"""Each example includes body_orientations as 3x3 lists."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
orients = ex["body_orientations"]
|
||||
assert len(orients) == ex["n_bodies"]
|
||||
for o in orients:
|
||||
assert len(o) == 3
|
||||
assert len(o[0]) == 3
|
||||
|
||||
def test_labels_structure(self) -> None:
|
||||
"""Each example has labels dict with expected sub-keys."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
labels = ex["labels"]
|
||||
assert "per_constraint" in labels
|
||||
assert "per_joint" in labels
|
||||
assert "per_body" in labels
|
||||
assert "assembly" in labels
|
||||
|
||||
def test_grounded_field_present(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert isinstance(ex["grounded"], bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch grounded ratio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchGroundedRatio:
|
||||
"""grounded_ratio controls the mix in batch generation."""
|
||||
|
||||
def test_all_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
|
||||
assert all(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_none_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
|
||||
assert not any(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_mixed_ratio(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
|
||||
grounded_count = sum(1 for ex in batch if ex["grounded"])
|
||||
# With 100 samples and p=0.5, should be roughly 50 +/- 20
|
||||
assert 20 < grounded_count < 80
|
||||
|
||||
def test_batch_axis_strategy_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
axis_strategy="cardinal",
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
def test_batch_parallel_axis_prob(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
parallel_axis_prob=0.5,
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed reproducibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSeedReproducibility:
|
||||
"""Different seeds produce different results."""
|
||||
|
||||
def test_different_seeds_differ(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=1)
|
||||
g2 = SyntheticAssemblyGenerator(seed=2)
|
||||
b1 = g1.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
b2 = g2.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
c1 = [ex["assembly_classification"] for ex in b1]
|
||||
c2 = [ex["assembly_classification"] for ex in b2]
|
||||
r1 = [ex["is_rigid"] for ex in b1]
|
||||
r2 = [ex["is_rigid"] for ex in b2]
|
||||
assert c1 != c2 or r1 != r2
|
||||
|
||||
def test_same_seed_identical(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=123)
|
||||
g2 = SyntheticAssemblyGenerator(seed=123)
|
||||
b1, j1, _ = g1.generate_tree_assembly(5)
|
||||
b2, j2, _ = g2.generate_tree_assembly(5)
|
||||
for a, b in zip(b1, b2, strict=True):
|
||||
np.testing.assert_array_almost_equal(a.position, b.position)
|
||||
assert len(j1) == len(j2)
|
||||
267
tests/datagen/test_jacobian.py
Normal file
267
tests/datagen/test_jacobian.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for solver.datagen.jacobian -- Jacobian rank verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _two_bodies() -> list[RigidBody]:
|
||||
return [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
]
|
||||
|
||||
|
||||
def _three_bodies() -> list[RigidBody]:
|
||||
return [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJacobianShape:
|
||||
"""Verify Jacobian matrix dimensions for each joint type."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"joint_type,expected_rows",
|
||||
[
|
||||
(JointType.FIXED, 6),
|
||||
(JointType.REVOLUTE, 5),
|
||||
(JointType.CYLINDRICAL, 4),
|
||||
(JointType.SLIDER, 5),
|
||||
(JointType.BALL, 3),
|
||||
(JointType.PLANAR, 3),
|
||||
(JointType.SCREW, 5),
|
||||
(JointType.UNIVERSAL, 4),
|
||||
(JointType.PARALLEL, 3),
|
||||
(JointType.PERPENDICULAR, 1),
|
||||
(JointType.DISTANCE, 1),
|
||||
],
|
||||
)
|
||||
def test_row_count(self, joint_type: JointType, expected_rows: int) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
joint = Joint(
|
||||
joint_id=0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=joint_type,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
)
|
||||
n_added = v.add_joint_constraints(joint)
|
||||
assert n_added == expected_rows
|
||||
|
||||
j = v.get_jacobian()
|
||||
assert j.shape == (expected_rows, 12) # 2 bodies * 6 cols
|
||||
|
||||
|
||||
class TestNumericalRank:
|
||||
"""Numerical rank checks for known configurations."""
|
||||
|
||||
def test_fixed_joint_rank_six(self) -> None:
|
||||
"""Fixed joint between 2 bodies: rank = 6."""
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
j = Joint(
|
||||
joint_id=0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
)
|
||||
v.add_joint_constraints(j)
|
||||
assert v.numerical_rank() == 6
|
||||
|
||||
def test_revolute_joint_rank_five(self) -> None:
|
||||
"""Revolute joint between 2 bodies: rank = 5."""
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
j = Joint(
|
||||
joint_id=0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
)
|
||||
v.add_joint_constraints(j)
|
||||
assert v.numerical_rank() == 5
|
||||
|
||||
def test_ball_joint_rank_three(self) -> None:
|
||||
"""Ball joint between 2 bodies: rank = 3."""
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
j = Joint(
|
||||
joint_id=0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.BALL,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
)
|
||||
v.add_joint_constraints(j)
|
||||
assert v.numerical_rank() == 3
|
||||
|
||||
def test_empty_jacobian_rank_zero(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
assert v.numerical_rank() == 0
|
||||
|
||||
|
||||
class TestParallelAxesDegeneracy:
|
||||
"""Parallel revolute axes create geometric dependencies."""
|
||||
|
||||
def _four_body_loop(self) -> list[RigidBody]:
|
||||
return [
|
||||
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
|
||||
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
|
||||
RigidBody(2, position=np.array([2.0, 2.0, 0.0])),
|
||||
RigidBody(3, position=np.array([0.0, 2.0, 0.0])),
|
||||
]
|
||||
|
||||
def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]:
|
||||
pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])]
|
||||
return [
|
||||
Joint(
|
||||
joint_id=i,
|
||||
body_a=a,
|
||||
body_b=b,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array(anc, dtype=float),
|
||||
anchor_b=np.array(anc, dtype=float),
|
||||
axis=axes[i],
|
||||
)
|
||||
for i, (a, b, anc) in enumerate(pairs)
|
||||
]
|
||||
|
||||
def test_parallel_has_lower_rank(self) -> None:
|
||||
"""4-body closed loop: all-parallel revolute axes produce lower
|
||||
Jacobian rank than mixed axes due to geometric dependency."""
|
||||
bodies = self._four_body_loop()
|
||||
z_axis = np.array([0.0, 0.0, 1.0])
|
||||
|
||||
# All axes parallel to Z
|
||||
v_par = JacobianVerifier(bodies)
|
||||
for j in self._loop_joints([z_axis] * 4):
|
||||
v_par.add_joint_constraints(j)
|
||||
rank_par = v_par.numerical_rank()
|
||||
|
||||
# Mixed axes
|
||||
mixed = [
|
||||
np.array([0.0, 0.0, 1.0]),
|
||||
np.array([0.0, 1.0, 0.0]),
|
||||
np.array([0.0, 0.0, 1.0]),
|
||||
np.array([1.0, 0.0, 0.0]),
|
||||
]
|
||||
v_mix = JacobianVerifier(bodies)
|
||||
for j in self._loop_joints(mixed):
|
||||
v_mix.add_joint_constraints(j)
|
||||
rank_mix = v_mix.numerical_rank()
|
||||
|
||||
assert rank_par < rank_mix
|
||||
|
||||
|
||||
class TestFindDependencies:
|
||||
"""Dependency detection."""
|
||||
|
||||
def test_fixed_joint_no_dependencies(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
j = Joint(
|
||||
joint_id=0,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
)
|
||||
v.add_joint_constraints(j)
|
||||
assert v.find_dependencies() == []
|
||||
|
||||
def test_duplicate_fixed_has_dependencies(self) -> None:
|
||||
"""Two fixed joints on same pair: second is fully dependent."""
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
for jid in range(2):
|
||||
v.add_joint_constraints(
|
||||
Joint(
|
||||
joint_id=jid,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.FIXED,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
)
|
||||
)
|
||||
deps = v.find_dependencies()
|
||||
assert len(deps) == 6 # Second fixed joint entirely redundant
|
||||
|
||||
def test_empty_no_dependencies(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
assert v.find_dependencies() == []
|
||||
|
||||
|
||||
class TestRowLabels:
|
||||
"""Row label metadata."""
|
||||
|
||||
def test_labels_match_rows(self) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
j = Joint(
|
||||
joint_id=7,
|
||||
body_a=0,
|
||||
body_b=1,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([1.0, 0.0, 0.0]),
|
||||
axis=np.array([0.0, 0.0, 1.0]),
|
||||
)
|
||||
v.add_joint_constraints(j)
|
||||
assert len(v.row_labels) == 5
|
||||
assert all(lab["joint_id"] == 7 for lab in v.row_labels)
|
||||
|
||||
|
||||
class TestPerpendicularPair:
|
||||
"""Internal _perpendicular_pair utility."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"axis",
|
||||
[
|
||||
np.array([1.0, 0.0, 0.0]),
|
||||
np.array([0.0, 1.0, 0.0]),
|
||||
np.array([0.0, 0.0, 1.0]),
|
||||
np.array([1.0, 1.0, 1.0]) / np.sqrt(3),
|
||||
],
|
||||
)
|
||||
def test_orthonormal(self, axis: np.ndarray) -> None:
|
||||
bodies = _two_bodies()
|
||||
v = JacobianVerifier(bodies)
|
||||
t1, t2 = v._perpendicular_pair(axis)
|
||||
|
||||
# All unit length
|
||||
np.testing.assert_allclose(np.linalg.norm(t1), 1.0, atol=1e-12)
|
||||
np.testing.assert_allclose(np.linalg.norm(t2), 1.0, atol=1e-12)
|
||||
|
||||
# Mutually perpendicular
|
||||
np.testing.assert_allclose(np.dot(axis, t1), 0.0, atol=1e-12)
|
||||
np.testing.assert_allclose(np.dot(axis, t2), 0.0, atol=1e-12)
|
||||
np.testing.assert_allclose(np.dot(t1, t2), 0.0, atol=1e-12)
|
||||
346
tests/datagen/test_labeling.py
Normal file
346
tests/datagen/test_labeling.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.labeling import (
|
||||
label_assembly,
|
||||
)
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
|
||||
return [RigidBody(body_id=i, position=np.array(pos)) for i, pos in enumerate(positions)]
|
||||
|
||||
|
||||
def _make_joint(
|
||||
jid: int,
|
||||
a: int,
|
||||
b: int,
|
||||
jtype: JointType,
|
||||
axis: tuple[float, ...] = (0.0, 0.0, 1.0),
|
||||
) -> Joint:
|
||||
return Joint(
|
||||
joint_id=jid,
|
||||
body_a=a,
|
||||
body_b=b,
|
||||
joint_type=jtype,
|
||||
anchor_a=np.zeros(3),
|
||||
anchor_b=np.zeros(3),
|
||||
axis=np.array(axis),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-constraint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstraintLabels:
|
||||
"""Per-constraint labels combine pebble game and Jacobian results."""
|
||||
|
||||
def test_fixed_joint_all_independent(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 6
|
||||
for cl in labels.per_constraint:
|
||||
assert cl.pebble_independent is True
|
||||
assert cl.jacobian_independent is True
|
||||
|
||||
def test_revolute_joint_all_independent(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 5
|
||||
for cl in labels.per_constraint:
|
||||
assert cl.pebble_independent is True
|
||||
assert cl.jacobian_independent is True
|
||||
|
||||
def test_chain_constraint_count(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 10 # 5 + 5
|
||||
|
||||
def test_constraint_joint_ids(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.BALL),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0]
|
||||
j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1]
|
||||
assert len(j0_constraints) == 5 # revolute
|
||||
assert len(j1_constraints) == 3 # ball
|
||||
|
||||
def test_overconstrained_has_pebble_redundant(self) -> None:
|
||||
"""Triangle with revolute joints: some constraints redundant."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent)
|
||||
assert pebble_redundant > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-joint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJointLabels:
|
||||
"""Per-joint aggregated labels."""
|
||||
|
||||
def test_fixed_joint_counts(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_joint) == 1
|
||||
jl = labels.per_joint[0]
|
||||
assert jl.joint_id == 0
|
||||
assert jl.independent_count == 6
|
||||
assert jl.redundant_count == 0
|
||||
assert jl.total == 6
|
||||
|
||||
def test_overconstrained_has_redundant_joints(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
total_redundant = sum(jl.redundant_count for jl in labels.per_joint)
|
||||
assert total_redundant > 0
|
||||
|
||||
def test_joint_total_equals_dof(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.BALL)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
jl = labels.per_joint[0]
|
||||
assert jl.total == 3 # ball has 3 DOF
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-body DOF labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBodyDofLabels:
|
||||
"""Per-body DOF signatures from nullspace projection."""
|
||||
|
||||
def test_fixed_joint_grounded_both_zero(self) -> None:
|
||||
"""Two bodies + fixed joint + grounded: both fully constrained."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
for bl in labels.per_body:
|
||||
assert bl.translational_dof == 0
|
||||
assert bl.rotational_dof == 0
|
||||
|
||||
def test_revolute_has_rotational_dof(self) -> None:
|
||||
"""Two bodies + revolute + grounded: body 1 has rotational DOF."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
b1 = next(b for b in labels.per_body if b.body_id == 1)
|
||||
# Revolute allows 1 rotation DOF
|
||||
assert b1.rotational_dof >= 1
|
||||
|
||||
def test_dof_bounds(self) -> None:
|
||||
"""All DOF values should be in [0, 3]."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
for bl in labels.per_body:
|
||||
assert 0 <= bl.translational_dof <= 3
|
||||
assert 0 <= bl.rotational_dof <= 3
|
||||
|
||||
def test_floating_more_dof_than_grounded(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
grounded = label_assembly(bodies, joints, ground_body=0)
|
||||
floating = label_assembly(bodies, joints, ground_body=None)
|
||||
g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body)
|
||||
f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body)
|
||||
assert f_total > g_total
|
||||
|
||||
def test_grounded_body_zero_dof(self) -> None:
|
||||
"""The grounded body should have 0 DOF."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
b0 = next(b for b in labels.per_body if b.body_id == 0)
|
||||
assert b0.translational_dof == 0
|
||||
assert b0.rotational_dof == 0
|
||||
|
||||
def test_body_count_matches(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.BALL),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_body) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assembly label
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAssemblyLabel:
|
||||
"""Assembly-wide summary labels."""
|
||||
|
||||
def test_underconstrained_chain(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.classification == "underconstrained"
|
||||
assert labels.assembly.is_rigid is False
|
||||
|
||||
def test_well_constrained(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.classification == "well-constrained"
|
||||
assert labels.assembly.is_rigid is True
|
||||
assert labels.assembly.is_minimally_rigid is True
|
||||
|
||||
def test_overconstrained(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.redundant_count > 0
|
||||
|
||||
def test_has_degeneracy_with_parallel_axes(self) -> None:
|
||||
"""Parallel revolute axes in a loop create geometric degeneracy."""
|
||||
z_axis = (0.0, 0.0, 1.0)
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.has_degeneracy is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToDict:
|
||||
"""to_dict produces JSON-serializable output."""
|
||||
|
||||
def test_top_level_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
assert set(d.keys()) == {
|
||||
"per_constraint",
|
||||
"per_joint",
|
||||
"per_body",
|
||||
"assembly",
|
||||
}
|
||||
|
||||
def test_per_constraint_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
for item in d["per_constraint"]:
|
||||
assert set(item.keys()) == {
|
||||
"joint_id",
|
||||
"constraint_idx",
|
||||
"pebble_independent",
|
||||
"jacobian_independent",
|
||||
}
|
||||
|
||||
def test_assembly_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
assert set(d["assembly"].keys()) == {
|
||||
"classification",
|
||||
"total_dof",
|
||||
"redundant_count",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
"has_degeneracy",
|
||||
}
|
||||
|
||||
def test_json_serializable(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
# Should not raise
|
||||
serialized = json.dumps(d)
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelAssemblyEdgeCases:
|
||||
"""Edge cases for label_assembly."""
|
||||
|
||||
def test_no_joints(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
labels = label_assembly(bodies, [], ground_body=0)
|
||||
assert len(labels.per_constraint) == 0
|
||||
assert len(labels.per_joint) == 0
|
||||
assert labels.assembly.classification == "underconstrained"
|
||||
# Non-ground body should be fully free
|
||||
b1 = next(b for b in labels.per_body if b.body_id == 1)
|
||||
assert b1.translational_dof == 3
|
||||
assert b1.rotational_dof == 3
|
||||
|
||||
def test_no_joints_floating(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0))
|
||||
labels = label_assembly(bodies, [], ground_body=None)
|
||||
assert len(labels.per_body) == 1
|
||||
assert labels.per_body[0].translational_dof == 3
|
||||
assert labels.per_body[0].rotational_dof == 3
|
||||
|
||||
def test_analysis_embedded(self) -> None:
|
||||
"""AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
analysis = labels.analysis
|
||||
assert hasattr(analysis, "combinatorial_classification")
|
||||
assert hasattr(analysis, "jacobian_rank")
|
||||
assert hasattr(analysis, "is_rigid")
|
||||
206
tests/datagen/test_pebble_game.py
Normal file
206
tests/datagen/test_pebble_game.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import Joint, JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint:
|
||||
"""Shorthand for a revolute joint between bodies *a* and *b*."""
|
||||
if axis is None:
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
return Joint(
|
||||
joint_id=jid,
|
||||
body_a=a,
|
||||
body_b=b,
|
||||
joint_type=JointType.REVOLUTE,
|
||||
axis=axis,
|
||||
)
|
||||
|
||||
|
||||
def _fixed(jid: int, a: int, b: int) -> Joint:
|
||||
return Joint(joint_id=jid, body_a=a, body_b=b, joint_type=JointType.FIXED)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddBody:
|
||||
"""Body registration basics."""
|
||||
|
||||
def test_single_body_six_pebbles(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_body(0)
|
||||
assert pg.state.free_pebbles[0] == 6
|
||||
|
||||
def test_duplicate_body_no_op(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_body(0)
|
||||
pg.add_body(0)
|
||||
assert pg.state.free_pebbles[0] == 6
|
||||
|
||||
def test_multiple_bodies(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
for i in range(5):
|
||||
pg.add_body(i)
|
||||
assert pg.get_dof() == 30 # 5 * 6
|
||||
|
||||
|
||||
class TestAddJoint:
|
||||
"""Joint insertion and DOF accounting."""
|
||||
|
||||
def test_revolute_removes_five_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
results = pg.add_joint(_revolute(0, 0, 1))
|
||||
assert len(results) == 5 # 5 scalar constraints
|
||||
assert all(r["independent"] for r in results)
|
||||
# 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles
|
||||
assert pg.get_dof() == 7
|
||||
|
||||
def test_fixed_removes_six_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
results = pg.add_joint(_fixed(0, 0, 1))
|
||||
assert len(results) == 6
|
||||
assert all(r["independent"] for r in results)
|
||||
assert pg.get_dof() == 6
|
||||
|
||||
def test_ball_removes_three_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL)
|
||||
results = pg.add_joint(j)
|
||||
assert len(results) == 3
|
||||
assert all(r["independent"] for r in results)
|
||||
assert pg.get_dof() == 9
|
||||
|
||||
|
||||
class TestTwoBodiesRevolute:
|
||||
"""Two bodies connected by a revolute -- demo scenario 1."""
|
||||
|
||||
def test_internal_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_revolute(0, 0, 1))
|
||||
# Total DOF = 7, internal = 7 - 6 = 1
|
||||
assert pg.get_internal_dof() == 1
|
||||
|
||||
def test_not_rigid(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_revolute(0, 0, 1))
|
||||
assert not pg.is_rigid()
|
||||
|
||||
def test_classification(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_revolute(0, 0, 1))
|
||||
assert pg.classify_assembly() == "underconstrained"
|
||||
|
||||
|
||||
class TestTwoBodiesFixed:
|
||||
"""Two bodies + fixed joint -- demo scenario 2."""
|
||||
|
||||
def test_zero_internal_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_fixed(0, 0, 1))
|
||||
assert pg.get_internal_dof() == 0
|
||||
|
||||
def test_rigid(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_fixed(0, 0, 1))
|
||||
assert pg.is_rigid()
|
||||
|
||||
def test_well_constrained(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_fixed(0, 0, 1))
|
||||
assert pg.classify_assembly() == "well-constrained"
|
||||
|
||||
|
||||
class TestTriangleRevolute:
|
||||
"""Triangle of 3 bodies with revolute joints -- demo scenario 3."""
|
||||
|
||||
@pytest.fixture()
|
||||
def pg(self) -> PebbleGame3D:
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_revolute(0, 0, 1))
|
||||
pg.add_joint(_revolute(1, 1, 2))
|
||||
pg.add_joint(_revolute(2, 2, 0))
|
||||
return pg
|
||||
|
||||
def test_has_redundant_edges(self, pg: PebbleGame3D) -> None:
|
||||
assert pg.get_redundant_count() > 0
|
||||
|
||||
def test_classification_overconstrained(self, pg: PebbleGame3D) -> None:
|
||||
# 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed)
|
||||
assert pg.classify_assembly() in ("overconstrained", "mixed")
|
||||
|
||||
def test_rigid(self, pg: PebbleGame3D) -> None:
|
||||
assert pg.is_rigid()
|
||||
|
||||
|
||||
class TestChainNotRigid:
|
||||
"""A serial chain of 4 bodies with revolute joints is never rigid."""
|
||||
|
||||
def test_chain_underconstrained(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
for i in range(3):
|
||||
pg.add_joint(_revolute(i, i, i + 1))
|
||||
assert not pg.is_rigid()
|
||||
assert pg.classify_assembly() == "underconstrained"
|
||||
|
||||
def test_chain_internal_dof(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
for i in range(3):
|
||||
pg.add_joint(_revolute(i, i, i + 1))
|
||||
# 4 bodies * 6 = 24, minus 15 independent = 9 free, internal = 3
|
||||
assert pg.get_internal_dof() == 3
|
||||
|
||||
|
||||
class TestEdgeResults:
|
||||
"""Result dicts returned by add_joint."""
|
||||
|
||||
def test_result_keys(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
results = pg.add_joint(_revolute(0, 0, 1))
|
||||
expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"}
|
||||
for r in results:
|
||||
assert set(r.keys()) == expected_keys
|
||||
|
||||
def test_edge_ids_sequential(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
r1 = pg.add_joint(_revolute(0, 0, 1))
|
||||
r2 = pg.add_joint(_revolute(1, 1, 2))
|
||||
all_ids = [r["edge_id"] for r in r1 + r2]
|
||||
assert all_ids == list(range(10))
|
||||
|
||||
def test_dof_remaining_monotonic(self) -> None:
|
||||
pg = PebbleGame3D()
|
||||
results = pg.add_joint(_revolute(0, 0, 1))
|
||||
dofs = [r["dof_remaining"] for r in results]
|
||||
# Should be non-increasing (each independent edge removes a pebble)
|
||||
for a, b in itertools.pairwise(dofs):
|
||||
assert a >= b
|
||||
|
||||
|
||||
class TestGroundedClassification:
|
||||
"""classify_assembly with grounded=True."""
|
||||
|
||||
def test_grounded_baseline_zero(self) -> None:
|
||||
"""With grounded=True the baseline is 0 (not 6)."""
|
||||
pg = PebbleGame3D()
|
||||
pg.add_joint(_fixed(0, 0, 1))
|
||||
# Ungrounded: well-constrained (6 pebbles = baseline 6)
|
||||
assert pg.classify_assembly(grounded=False) == "well-constrained"
|
||||
# Grounded: the 6 remaining pebbles on body 1 exceed baseline 0,
|
||||
# so the raw pebble game (without a virtual ground body) sees this
|
||||
# as underconstrained. The analysis function handles this properly
|
||||
# by adding a virtual ground body.
|
||||
assert pg.classify_assembly(grounded=True) == "underconstrained"
|
||||
163
tests/datagen/test_types.py
Normal file
163
tests/datagen/test_types.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for solver.datagen.types -- shared data types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
PebbleState,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
|
||||
class TestJointType:
|
||||
"""JointType enum construction and DOF values."""
|
||||
|
||||
EXPECTED_DOF: ClassVar[dict[str, int]] = {
|
||||
"FIXED": 6,
|
||||
"REVOLUTE": 5,
|
||||
"CYLINDRICAL": 4,
|
||||
"SLIDER": 5,
|
||||
"BALL": 3,
|
||||
"PLANAR": 3,
|
||||
"SCREW": 5,
|
||||
"UNIVERSAL": 4,
|
||||
"PARALLEL": 3,
|
||||
"PERPENDICULAR": 1,
|
||||
"DISTANCE": 1,
|
||||
}
|
||||
|
||||
def test_member_count(self) -> None:
|
||||
assert len(JointType) == 11
|
||||
|
||||
@pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
|
||||
def test_dof_values(self, name: str, dof: int) -> None:
|
||||
assert JointType[name].dof == dof
|
||||
|
||||
def test_access_by_name(self) -> None:
|
||||
assert JointType["REVOLUTE"] is JointType.REVOLUTE
|
||||
|
||||
def test_value_is_tuple(self) -> None:
|
||||
assert JointType.REVOLUTE.value == (1, 5)
|
||||
assert JointType.REVOLUTE.dof == 5
|
||||
|
||||
|
||||
class TestRigidBody:
|
||||
"""RigidBody dataclass defaults and construction."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
body = RigidBody(body_id=0)
|
||||
np.testing.assert_array_equal(body.position, np.zeros(3))
|
||||
np.testing.assert_array_equal(body.orientation, np.eye(3))
|
||||
assert body.local_anchors == {}
|
||||
|
||||
def test_custom_position(self) -> None:
|
||||
pos = np.array([1.0, 2.0, 3.0])
|
||||
body = RigidBody(body_id=7, position=pos)
|
||||
np.testing.assert_array_equal(body.position, pos)
|
||||
assert body.body_id == 7
|
||||
|
||||
def test_local_anchors_mutable(self) -> None:
|
||||
body = RigidBody(body_id=0)
|
||||
body.local_anchors["top"] = np.array([0.0, 0.0, 1.0])
|
||||
assert "top" in body.local_anchors
|
||||
|
||||
def test_default_factory_isolation(self) -> None:
|
||||
"""Each instance gets its own default containers."""
|
||||
b1 = RigidBody(body_id=0)
|
||||
b2 = RigidBody(body_id=1)
|
||||
b1.local_anchors["x"] = np.zeros(3)
|
||||
assert "x" not in b2.local_anchors
|
||||
|
||||
|
||||
class TestJoint:
|
||||
"""Joint dataclass defaults and construction."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE)
|
||||
np.testing.assert_array_equal(j.anchor_a, np.zeros(3))
|
||||
np.testing.assert_array_equal(j.anchor_b, np.zeros(3))
|
||||
np.testing.assert_array_equal(j.axis, np.array([0.0, 0.0, 1.0]))
|
||||
assert j.pitch == 0.0
|
||||
|
||||
def test_full_construction(self) -> None:
|
||||
j = Joint(
|
||||
joint_id=5,
|
||||
body_a=2,
|
||||
body_b=3,
|
||||
joint_type=JointType.SCREW,
|
||||
anchor_a=np.array([1.0, 0.0, 0.0]),
|
||||
anchor_b=np.array([2.0, 0.0, 0.0]),
|
||||
axis=np.array([1.0, 0.0, 0.0]),
|
||||
pitch=0.5,
|
||||
)
|
||||
assert j.joint_id == 5
|
||||
assert j.joint_type is JointType.SCREW
|
||||
assert j.pitch == 0.5
|
||||
|
||||
|
||||
class TestPebbleState:
|
||||
"""PebbleState dataclass defaults."""
|
||||
|
||||
def test_defaults(self) -> None:
|
||||
s = PebbleState()
|
||||
assert s.free_pebbles == {}
|
||||
assert s.directed_edges == {}
|
||||
assert s.independent_edges == set()
|
||||
assert s.redundant_edges == set()
|
||||
assert s.incoming == {}
|
||||
assert s.outgoing == {}
|
||||
|
||||
def test_default_factory_isolation(self) -> None:
|
||||
s1 = PebbleState()
|
||||
s2 = PebbleState()
|
||||
s1.free_pebbles[0] = 6
|
||||
assert 0 not in s2.free_pebbles
|
||||
|
||||
|
||||
class TestConstraintAnalysis:
|
||||
"""ConstraintAnalysis dataclass construction."""
|
||||
|
||||
def test_construction(self) -> None:
|
||||
ca = ConstraintAnalysis(
|
||||
combinatorial_dof=6,
|
||||
combinatorial_internal_dof=0,
|
||||
combinatorial_redundant=0,
|
||||
combinatorial_classification="well-constrained",
|
||||
per_edge_results=[],
|
||||
jacobian_rank=6,
|
||||
jacobian_nullity=0,
|
||||
jacobian_internal_dof=0,
|
||||
numerically_dependent=[],
|
||||
geometric_degeneracies=0,
|
||||
is_rigid=True,
|
||||
is_minimally_rigid=True,
|
||||
)
|
||||
assert ca.is_rigid is True
|
||||
assert ca.is_minimally_rigid is True
|
||||
assert ca.combinatorial_classification == "well-constrained"
|
||||
|
||||
def test_per_edge_results_typing(self) -> None:
|
||||
"""per_edge_results accepts list[dict[str, Any]]."""
|
||||
ca = ConstraintAnalysis(
|
||||
combinatorial_dof=7,
|
||||
combinatorial_internal_dof=1,
|
||||
combinatorial_redundant=0,
|
||||
combinatorial_classification="underconstrained",
|
||||
per_edge_results=[{"edge_id": 0, "independent": True}],
|
||||
jacobian_rank=5,
|
||||
jacobian_nullity=1,
|
||||
jacobian_internal_dof=1,
|
||||
numerically_dependent=[],
|
||||
geometric_degeneracies=0,
|
||||
is_rigid=False,
|
||||
is_minimally_rigid=False,
|
||||
)
|
||||
assert len(ca.per_edge_results) == 1
|
||||
assert ca.per_edge_results[0]["edge_id"] == 0
|
||||
Reference in New Issue
Block a user