Compare commits
13 Commits
e32c9cd793
...
5d1988b513
| Author | SHA1 | Date | |
|---|---|---|---|
| 5d1988b513 | |||
| f29060491e | |||
| 8a49f8ef40 | |||
| 78289494e2 | |||
| 0b5813b5a9 | |||
| dc742bfc82 | |||
| 831a10cdb4 | |||
| 9a31df4988 | |||
| 455b6318d9 | |||
| 35d4ef736f | |||
| 1b6135129e | |||
| 363b49281b | |||
| f61d005400 |
65
.gitea/workflows/ci.yaml
Normal file
65
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install ruff mypy
|
||||||
|
pip install -e ".[dev]" || pip install ruff mypy numpy
|
||||||
|
|
||||||
|
- name: Ruff check
|
||||||
|
run: ruff check solver/ freecad/ tests/ scripts/
|
||||||
|
|
||||||
|
- name: Ruff format check
|
||||||
|
run: ruff format --check solver/ freecad/ tests/ scripts/
|
||||||
|
|
||||||
|
type-check:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install mypy numpy
|
||||||
|
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install torch-geometric
|
||||||
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
|
- name: Mypy
|
||||||
|
run: mypy solver/ freecad/
|
||||||
|
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: |
|
||||||
|
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
pip install torch-geometric
|
||||||
|
pip install -e ".[train,dev]"
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest tests/ freecad/tests/ -v --tb=short
|
||||||
77
.gitignore
vendored
77
.gitignore
vendored
@@ -1,44 +1,83 @@
|
|||||||
# Prerequisites
|
# C++ compiled objects
|
||||||
*.d
|
*.d
|
||||||
|
|
||||||
# Compiled Object files
|
|
||||||
*.slo
|
*.slo
|
||||||
*.lo
|
*.lo
|
||||||
*.o
|
*.o
|
||||||
*.obj
|
*.obj
|
||||||
|
|
||||||
# Precompiled Headers
|
|
||||||
*.gch
|
*.gch
|
||||||
*.pch
|
*.pch
|
||||||
|
|
||||||
# Compiled Dynamic libraries
|
# C++ libraries
|
||||||
*.so
|
*.so
|
||||||
*.dylib
|
*.dylib
|
||||||
*.dll
|
*.dll
|
||||||
|
|
||||||
# Fortran module files
|
|
||||||
*.mod
|
|
||||||
*.smod
|
|
||||||
|
|
||||||
# Compiled Static libraries
|
|
||||||
*.lai
|
*.lai
|
||||||
*.la
|
*.la
|
||||||
*.a
|
*.a
|
||||||
*.lib
|
*.lib
|
||||||
|
|
||||||
# Executables
|
# C++ executables
|
||||||
*.exe
|
*.exe
|
||||||
*.out
|
*.out
|
||||||
*.app
|
*.app
|
||||||
|
|
||||||
.vs
|
# C++ build
|
||||||
|
build/
|
||||||
|
cmake-build-debug/
|
||||||
|
.vs/
|
||||||
x64/
|
x64/
|
||||||
|
temp/
|
||||||
|
|
||||||
|
# OndselSolver test artifacts
|
||||||
*.bak
|
*.bak
|
||||||
assembly.asmt
|
assembly.asmt
|
||||||
|
|
||||||
build
|
|
||||||
cmake-build-debug
|
|
||||||
.idea
|
|
||||||
temp/
|
|
||||||
/testapp/draggingBackhoe.log
|
/testapp/draggingBackhoe.log
|
||||||
/testapp/runPreDragBackhoe.asmt
|
/testapp/runPreDragBackhoe.asmt
|
||||||
|
|
||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
*.egg
|
||||||
|
|
||||||
|
# Virtual environments
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
*.swp
|
||||||
|
*.swo
|
||||||
|
*~
|
||||||
|
|
||||||
|
# mypy / ruff / pytest
|
||||||
|
.mypy_cache/
|
||||||
|
.ruff_cache/
|
||||||
|
.pytest_cache/
|
||||||
|
|
||||||
|
# Data (large files tracked separately)
|
||||||
|
data/synthetic/*.pt
|
||||||
|
data/fusion360/*.json
|
||||||
|
data/fusion360/*.step
|
||||||
|
data/processed/*.pt
|
||||||
|
!data/**/.gitkeep
|
||||||
|
|
||||||
|
# Model checkpoints
|
||||||
|
*.ckpt
|
||||||
|
*.pth
|
||||||
|
*.onnx
|
||||||
|
*.torchscript
|
||||||
|
|
||||||
|
# Experiment tracking
|
||||||
|
wandb/
|
||||||
|
runs/
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Environment
|
||||||
|
.env
|
||||||
|
|||||||
23
.pre-commit-config.yaml
Normal file
23
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
repos:
|
||||||
|
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||||
|
rev: v0.3.4
|
||||||
|
hooks:
|
||||||
|
- id: ruff
|
||||||
|
args: [--fix]
|
||||||
|
- id: ruff-format
|
||||||
|
|
||||||
|
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||||
|
rev: v1.8.0
|
||||||
|
hooks:
|
||||||
|
- id: mypy
|
||||||
|
additional_dependencies:
|
||||||
|
- torch>=2.2
|
||||||
|
- numpy>=1.26
|
||||||
|
args: [--ignore-missing-imports]
|
||||||
|
|
||||||
|
- repo: https://github.com/compilerla/conventional-pre-commit
|
||||||
|
rev: v3.1.0
|
||||||
|
hooks:
|
||||||
|
- id: conventional-pre-commit
|
||||||
|
stages: [commit-msg]
|
||||||
|
args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert]
|
||||||
61
Dockerfile
Normal file
61
Dockerfile
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base
|
||||||
|
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
# System deps
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
python3.11 python3.11-venv python3.11-dev python3-pip \
|
||||||
|
git wget curl \
|
||||||
|
# FreeCAD headless deps
|
||||||
|
freecad \
|
||||||
|
libgl1-mesa-glx libglib2.0-0 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
|
||||||
|
|
||||||
|
# Create venv
|
||||||
|
RUN python -m venv /opt/venv
|
||||||
|
ENV PATH="/opt/venv/bin:$PATH"
|
||||||
|
|
||||||
|
# Install PyTorch with CUDA
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||||
|
|
||||||
|
# Install PyG
|
||||||
|
RUN pip install --no-cache-dir \
|
||||||
|
torch-geometric \
|
||||||
|
pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
|
||||||
|
-f https://data.pyg.org/whl/torch-2.4.0+cu124.html
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
# Install project
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir -e ".[train,dev]" || true
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||||
|
|
||||||
|
# -------------------------------------------------------------------
|
||||||
|
FROM base AS cpu
|
||||||
|
|
||||||
|
# CPU-only variant (for CI and non-GPU environments)
|
||||||
|
FROM python:3.11-slim AS cpu-only
|
||||||
|
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
|
||||||
|
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||||
|
git freecad libgl1-mesa-glx libglib2.0-0 \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
WORKDIR /workspace
|
||||||
|
|
||||||
|
COPY pyproject.toml .
|
||||||
|
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||||
|
RUN pip install --no-cache-dir torch-geometric
|
||||||
|
|
||||||
|
COPY . .
|
||||||
|
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||||
|
|
||||||
|
CMD ["pytest", "tests/", "-v"]
|
||||||
48
Makefile
Normal file
48
Makefile
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
.PHONY: train test lint data-gen export format type-check install dev clean help
|
||||||
|
|
||||||
|
PYTHON ?= python
|
||||||
|
PYTEST ?= pytest
|
||||||
|
RUFF ?= ruff
|
||||||
|
MYPY ?= mypy
|
||||||
|
|
||||||
|
help: ## Show this help
|
||||||
|
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
|
||||||
|
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||||
|
|
||||||
|
install: ## Install core dependencies
|
||||||
|
pip install -e .
|
||||||
|
|
||||||
|
dev: ## Install all dependencies including dev tools
|
||||||
|
pip install -e ".[train,dev]"
|
||||||
|
pre-commit install
|
||||||
|
pre-commit install --hook-type commit-msg
|
||||||
|
|
||||||
|
train: ## Run training (pass CONFIG=path/to/config.yaml)
|
||||||
|
$(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG))
|
||||||
|
|
||||||
|
test: ## Run test suite
|
||||||
|
$(PYTEST) tests/ freecad/tests/ -v --tb=short
|
||||||
|
|
||||||
|
lint: ## Run ruff linter
|
||||||
|
$(RUFF) check solver/ freecad/ tests/ scripts/
|
||||||
|
|
||||||
|
format: ## Format code with ruff
|
||||||
|
$(RUFF) format solver/ freecad/ tests/ scripts/
|
||||||
|
$(RUFF) check --fix solver/ freecad/ tests/ scripts/
|
||||||
|
|
||||||
|
type-check: ## Run mypy type checker
|
||||||
|
$(MYPY) solver/ freecad/
|
||||||
|
|
||||||
|
data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml)
|
||||||
|
$(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG))
|
||||||
|
|
||||||
|
export: ## Export trained model for deployment
|
||||||
|
$(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL))
|
||||||
|
|
||||||
|
clean: ## Remove build artifacts and caches
|
||||||
|
rm -rf build/ dist/ *.egg-info/
|
||||||
|
rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/
|
||||||
|
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
|
||||||
|
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||||
|
|
||||||
|
check: lint type-check test ## Run all checks (lint, type-check, test)
|
||||||
94
README.md
94
README.md
@@ -1,7 +1,91 @@
|
|||||||
# MbDCode
|
# Kindred Solver
|
||||||
Assembly Constraints and Multibody Dynamics code
|
|
||||||
|
|
||||||
Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited)
|
Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer.
|
||||||
|
|
||||||
The MbD theory is at
|
## Components
|
||||||
https://github.com/Ondsel-Development/MbDTheory
|
|
||||||
|
### OndselSolver (C++)
|
||||||
|
|
||||||
|
Numerical assembly constraint solver using multibody dynamics. Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver.
|
||||||
|
|
||||||
|
- Source: `OndselSolver/`
|
||||||
|
- Entry point: `OndselSolverMain/`
|
||||||
|
- Tests: `tests/`, `testapp/`
|
||||||
|
- Build: CMake
|
||||||
|
|
||||||
|
**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory)
|
||||||
|
|
||||||
|
#### Building
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||||
|
cmake --build build
|
||||||
|
```
|
||||||
|
|
||||||
|
### ML Solver Layer (Python)
|
||||||
|
|
||||||
|
Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns.
|
||||||
|
|
||||||
|
- Core library: `solver/`
|
||||||
|
- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling)
|
||||||
|
- Model architectures: `solver/models/` (GIN, GAT, NNConv)
|
||||||
|
- Training: `solver/training/`
|
||||||
|
- Inference: `solver/inference/`
|
||||||
|
- FreeCAD integration: `freecad/`
|
||||||
|
- Configuration: `configs/` (Hydra)
|
||||||
|
|
||||||
|
#### Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
pip install -e ".[train,dev]"
|
||||||
|
pre-commit install
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Usage
|
||||||
|
|
||||||
|
```bash
|
||||||
|
make help # show all targets
|
||||||
|
make dev # install all deps + pre-commit hooks
|
||||||
|
make test # run tests
|
||||||
|
make lint # run ruff linter
|
||||||
|
make check # lint + type-check + test
|
||||||
|
make data-gen # generate synthetic data
|
||||||
|
make train # run training
|
||||||
|
make export # export model
|
||||||
|
```
|
||||||
|
|
||||||
|
Docker is also supported:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up train # GPU training
|
||||||
|
docker compose up test # run tests
|
||||||
|
docker compose up data-gen # generate synthetic data
|
||||||
|
```
|
||||||
|
|
||||||
|
## Repository structure
|
||||||
|
|
||||||
|
```
|
||||||
|
kindred-solver/
|
||||||
|
├── OndselSolver/ # C++ numerical solver library
|
||||||
|
├── OndselSolverMain/ # C++ solver CLI entry point
|
||||||
|
├── tests/ # C++ unit tests + Python tests
|
||||||
|
├── testapp/ # C++ test application
|
||||||
|
├── solver/ # Python ML solver library
|
||||||
|
│ ├── datagen/ # Synthetic data generation (pebble game)
|
||||||
|
│ ├── datasets/ # PyG dataset adapters
|
||||||
|
│ ├── models/ # GNN architectures
|
||||||
|
│ ├── training/ # Training loops
|
||||||
|
│ ├── evaluation/ # Metrics and visualization
|
||||||
|
│ └── inference/ # Runtime prediction API
|
||||||
|
├── freecad/ # FreeCAD workbench integration
|
||||||
|
├── configs/ # Hydra configs (dataset, model, training, export)
|
||||||
|
├── scripts/ # CLI utilities
|
||||||
|
├── data/ # Datasets (not committed)
|
||||||
|
├── export/ # Model packaging
|
||||||
|
└── docs/ # Documentation
|
||||||
|
```
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE))
|
||||||
|
ML Solver Layer: Apache-2.0
|
||||||
|
|||||||
12
configs/dataset/fusion360.yaml
Normal file
12
configs/dataset/fusion360.yaml
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
# Fusion 360 Gallery dataset config
|
||||||
|
name: fusion360
|
||||||
|
data_dir: data/fusion360
|
||||||
|
output_dir: data/processed
|
||||||
|
|
||||||
|
splits:
|
||||||
|
train: 0.8
|
||||||
|
val: 0.1
|
||||||
|
test: 0.1
|
||||||
|
|
||||||
|
stratify_by: complexity
|
||||||
|
seed: 42
|
||||||
26
configs/dataset/synthetic.yaml
Normal file
26
configs/dataset/synthetic.yaml
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
# Synthetic dataset generation config
|
||||||
|
name: synthetic
|
||||||
|
num_assemblies: 100000
|
||||||
|
output_dir: data/synthetic
|
||||||
|
shard_size: 1000
|
||||||
|
|
||||||
|
complexity_distribution:
|
||||||
|
simple: 0.4 # 2-5 bodies
|
||||||
|
medium: 0.4 # 6-15 bodies
|
||||||
|
complex: 0.2 # 16-50 bodies
|
||||||
|
|
||||||
|
body_count:
|
||||||
|
min: 2
|
||||||
|
max: 50
|
||||||
|
|
||||||
|
templates:
|
||||||
|
- chain
|
||||||
|
- tree
|
||||||
|
- loop
|
||||||
|
- star
|
||||||
|
- mixed
|
||||||
|
|
||||||
|
grounded_ratio: 0.5
|
||||||
|
seed: 42
|
||||||
|
num_workers: 4
|
||||||
|
checkpoint_every: 5
|
||||||
25
configs/export/production.yaml
Normal file
25
configs/export/production.yaml
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
# Production model export config
|
||||||
|
model_checkpoint: checkpoints/finetune/best_val_loss.ckpt
|
||||||
|
output_dir: export/
|
||||||
|
|
||||||
|
formats:
|
||||||
|
onnx:
|
||||||
|
enabled: true
|
||||||
|
opset_version: 17
|
||||||
|
dynamic_axes: true
|
||||||
|
torchscript:
|
||||||
|
enabled: true
|
||||||
|
|
||||||
|
model_card:
|
||||||
|
version: "0.1.0"
|
||||||
|
architecture: baseline
|
||||||
|
training_data:
|
||||||
|
- synthetic_100k
|
||||||
|
- fusion360_gallery
|
||||||
|
|
||||||
|
size_budget_mb: 50
|
||||||
|
|
||||||
|
inference:
|
||||||
|
device: cpu
|
||||||
|
batch_size: 1
|
||||||
|
confidence_threshold: 0.8
|
||||||
24
configs/model/baseline.yaml
Normal file
24
configs/model/baseline.yaml
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
# Baseline GIN model config
|
||||||
|
name: baseline
|
||||||
|
architecture: gin
|
||||||
|
|
||||||
|
encoder:
|
||||||
|
num_layers: 3
|
||||||
|
hidden_dim: 128
|
||||||
|
dropout: 0.1
|
||||||
|
|
||||||
|
node_features_dim: 22
|
||||||
|
edge_features_dim: 22
|
||||||
|
|
||||||
|
heads:
|
||||||
|
edge_classification:
|
||||||
|
enabled: true
|
||||||
|
hidden_dim: 64
|
||||||
|
graph_classification:
|
||||||
|
enabled: true
|
||||||
|
num_classes: 4 # rigid, under, over, mixed
|
||||||
|
joint_type:
|
||||||
|
enabled: true
|
||||||
|
num_classes: 12
|
||||||
|
dof_regression:
|
||||||
|
enabled: true
|
||||||
28
configs/model/gat.yaml
Normal file
28
configs/model/gat.yaml
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
# Advanced GAT model config
|
||||||
|
name: gat_solver
|
||||||
|
architecture: gat
|
||||||
|
|
||||||
|
encoder:
|
||||||
|
num_layers: 4
|
||||||
|
hidden_dim: 256
|
||||||
|
num_heads: 8
|
||||||
|
dropout: 0.1
|
||||||
|
residual: true
|
||||||
|
|
||||||
|
node_features_dim: 22
|
||||||
|
edge_features_dim: 22
|
||||||
|
|
||||||
|
heads:
|
||||||
|
edge_classification:
|
||||||
|
enabled: true
|
||||||
|
hidden_dim: 128
|
||||||
|
graph_classification:
|
||||||
|
enabled: true
|
||||||
|
num_classes: 4
|
||||||
|
joint_type:
|
||||||
|
enabled: true
|
||||||
|
num_classes: 12
|
||||||
|
dof_regression:
|
||||||
|
enabled: true
|
||||||
|
dof_tracking:
|
||||||
|
enabled: true
|
||||||
45
configs/training/finetune.yaml
Normal file
45
configs/training/finetune.yaml
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
# Fine-tuning on real data config
|
||||||
|
phase: finetune
|
||||||
|
|
||||||
|
dataset: fusion360
|
||||||
|
model: baseline
|
||||||
|
|
||||||
|
pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
name: adamw
|
||||||
|
lr: 1e-5
|
||||||
|
weight_decay: 1e-4
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
name: cosine_annealing
|
||||||
|
T_max: 50
|
||||||
|
eta_min: 1e-7
|
||||||
|
|
||||||
|
training:
|
||||||
|
epochs: 50
|
||||||
|
batch_size: 32
|
||||||
|
gradient_clip: 1.0
|
||||||
|
early_stopping_patience: 10
|
||||||
|
amp: true
|
||||||
|
freeze_encoder: false # set true for frozen encoder experiment
|
||||||
|
|
||||||
|
loss:
|
||||||
|
edge_weight: 1.0
|
||||||
|
graph_weight: 0.5
|
||||||
|
joint_type_weight: 0.3
|
||||||
|
dof_weight: 0.2
|
||||||
|
redundant_penalty: 2.0
|
||||||
|
|
||||||
|
checkpointing:
|
||||||
|
save_best_val_loss: true
|
||||||
|
save_best_val_accuracy: true
|
||||||
|
save_every_n_epochs: 5
|
||||||
|
checkpoint_dir: checkpoints/finetune
|
||||||
|
|
||||||
|
logging:
|
||||||
|
backend: wandb
|
||||||
|
project: kindred-solver
|
||||||
|
log_every_n_steps: 20
|
||||||
|
|
||||||
|
seed: 42
|
||||||
42
configs/training/pretrain.yaml
Normal file
42
configs/training/pretrain.yaml
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
# Synthetic pre-training config
|
||||||
|
phase: pretrain
|
||||||
|
|
||||||
|
dataset: synthetic
|
||||||
|
model: baseline
|
||||||
|
|
||||||
|
optimizer:
|
||||||
|
name: adamw
|
||||||
|
lr: 1e-3
|
||||||
|
weight_decay: 1e-4
|
||||||
|
|
||||||
|
scheduler:
|
||||||
|
name: cosine_annealing
|
||||||
|
T_max: 100
|
||||||
|
eta_min: 1e-6
|
||||||
|
|
||||||
|
training:
|
||||||
|
epochs: 100
|
||||||
|
batch_size: 64
|
||||||
|
gradient_clip: 1.0
|
||||||
|
early_stopping_patience: 10
|
||||||
|
amp: true
|
||||||
|
|
||||||
|
loss:
|
||||||
|
edge_weight: 1.0
|
||||||
|
graph_weight: 0.5
|
||||||
|
joint_type_weight: 0.3
|
||||||
|
dof_weight: 0.2
|
||||||
|
redundant_penalty: 2.0 # safety loss multiplier
|
||||||
|
|
||||||
|
checkpointing:
|
||||||
|
save_best_val_loss: true
|
||||||
|
save_best_val_accuracy: true
|
||||||
|
save_every_n_epochs: 10
|
||||||
|
checkpoint_dir: checkpoints/pretrain
|
||||||
|
|
||||||
|
logging:
|
||||||
|
backend: wandb # or tensorboard
|
||||||
|
project: kindred-solver
|
||||||
|
log_every_n_steps: 50
|
||||||
|
|
||||||
|
seed: 42
|
||||||
0
data/fusion360/.gitkeep
Normal file
0
data/fusion360/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
39
docker-compose.yml
Normal file
39
docker-compose.yml
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
services:
|
||||||
|
train:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
target: base
|
||||||
|
volumes:
|
||||||
|
- .:/workspace
|
||||||
|
- ./data:/workspace/data
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: all
|
||||||
|
capabilities: [gpu]
|
||||||
|
command: make train
|
||||||
|
environment:
|
||||||
|
- CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
|
||||||
|
- WANDB_API_KEY=${WANDB_API_KEY:-}
|
||||||
|
|
||||||
|
test:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
target: cpu-only
|
||||||
|
volumes:
|
||||||
|
- .:/workspace
|
||||||
|
command: make check
|
||||||
|
|
||||||
|
data-gen:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
target: base
|
||||||
|
volumes:
|
||||||
|
- .:/workspace
|
||||||
|
- ./data:/workspace/data
|
||||||
|
command: make data-gen
|
||||||
0
docs/.gitkeep
Normal file
0
docs/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
freecad/__init__.py
Normal file
0
freecad/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
97
pyproject.toml
Normal file
97
pyproject.toml
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
[build-system]
|
||||||
|
requires = ["hatchling"]
|
||||||
|
build-backend = "hatchling.build"
|
||||||
|
|
||||||
|
[project]
|
||||||
|
name = "kindred-solver"
|
||||||
|
version = "0.1.0"
|
||||||
|
description = "Assembly constraint prediction via GNN for Kindred Create"
|
||||||
|
readme = "README.md"
|
||||||
|
license = "Apache-2.0"
|
||||||
|
requires-python = ">=3.11"
|
||||||
|
authors = [
|
||||||
|
{ name = "Kindred Systems" },
|
||||||
|
]
|
||||||
|
classifiers = [
|
||||||
|
"Development Status :: 3 - Alpha",
|
||||||
|
"License :: OSI Approved :: Apache Software License",
|
||||||
|
"Programming Language :: Python :: 3",
|
||||||
|
"Programming Language :: Python :: 3.11",
|
||||||
|
"Programming Language :: Python :: 3.12",
|
||||||
|
"Topic :: Scientific/Engineering",
|
||||||
|
]
|
||||||
|
dependencies = [
|
||||||
|
"torch>=2.2",
|
||||||
|
"torch-geometric>=2.5",
|
||||||
|
"numpy>=1.26",
|
||||||
|
"scipy>=1.12",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
train = [
|
||||||
|
"wandb>=0.16",
|
||||||
|
"tensorboard>=2.16",
|
||||||
|
"hydra-core>=1.3",
|
||||||
|
"omegaconf>=2.3",
|
||||||
|
"matplotlib>=3.8",
|
||||||
|
"networkx>=3.2",
|
||||||
|
]
|
||||||
|
freecad = [
|
||||||
|
"pyside6>=6.6",
|
||||||
|
]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0",
|
||||||
|
"pytest-cov>=4.1",
|
||||||
|
"ruff>=0.3",
|
||||||
|
"mypy>=1.8",
|
||||||
|
"pre-commit>=3.6",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.urls]
|
||||||
|
Repository = "https://git.kindred-systems.com/kindred/solver"
|
||||||
|
|
||||||
|
[tool.hatch.build.targets.wheel]
|
||||||
|
packages = ["solver", "freecad"]
|
||||||
|
|
||||||
|
[tool.ruff]
|
||||||
|
target-version = "py311"
|
||||||
|
line-length = 100
|
||||||
|
|
||||||
|
[tool.ruff.lint]
|
||||||
|
select = [
|
||||||
|
"E", # pycodestyle errors
|
||||||
|
"W", # pycodestyle warnings
|
||||||
|
"F", # pyflakes
|
||||||
|
"I", # isort
|
||||||
|
"N", # pep8-naming
|
||||||
|
"UP", # pyupgrade
|
||||||
|
"B", # flake8-bugbear
|
||||||
|
"SIM", # flake8-simplify
|
||||||
|
"TCH", # flake8-type-checking
|
||||||
|
"RUF", # ruff-specific
|
||||||
|
]
|
||||||
|
|
||||||
|
[tool.ruff.lint.isort]
|
||||||
|
known-first-party = ["solver", "freecad"]
|
||||||
|
|
||||||
|
[tool.mypy]
|
||||||
|
python_version = "3.11"
|
||||||
|
warn_return_any = true
|
||||||
|
warn_unused_configs = true
|
||||||
|
disallow_untyped_defs = true
|
||||||
|
check_untyped_defs = true
|
||||||
|
|
||||||
|
[[tool.mypy.overrides]]
|
||||||
|
module = [
|
||||||
|
"torch.*",
|
||||||
|
"torch_geometric.*",
|
||||||
|
"scipy.*",
|
||||||
|
"wandb.*",
|
||||||
|
"hydra.*",
|
||||||
|
"omegaconf.*",
|
||||||
|
]
|
||||||
|
ignore_missing_imports = true
|
||||||
|
|
||||||
|
[tool.pytest.ini_options]
|
||||||
|
testpaths = ["tests", "freecad/tests"]
|
||||||
|
addopts = "-v --tb=short"
|
||||||
115
scripts/generate_synthetic.py
Normal file
115
scripts/generate_synthetic.py
Normal file
@@ -0,0 +1,115 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
"""Generate synthetic assembly dataset for kindred-solver training.
|
||||||
|
|
||||||
|
Usage (argparse fallback — always available)::
|
||||||
|
|
||||||
|
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
|
||||||
|
|
||||||
|
Usage (Hydra — when hydra-core is installed)::
|
||||||
|
|
||||||
|
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
|
||||||
|
def _try_hydra_main() -> bool:
|
||||||
|
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
||||||
|
try:
|
||||||
|
import hydra # type: ignore[import-untyped]
|
||||||
|
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
||||||
|
except ImportError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@hydra.main(
|
||||||
|
config_path="../configs/dataset",
|
||||||
|
config_name="synthetic",
|
||||||
|
version_base=None,
|
||||||
|
)
|
||||||
|
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
||||||
|
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||||
|
|
||||||
|
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||||
|
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||||
|
DatasetGenerator(config).run()
|
||||||
|
|
||||||
|
_run() # type: ignore[no-untyped-call]
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def _argparse_main() -> None:
|
||||||
|
"""Fallback CLI using argparse."""
|
||||||
|
import argparse
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Generate synthetic assembly dataset",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--config",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="Path to YAML config file (optional)",
|
||||||
|
)
|
||||||
|
parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
|
||||||
|
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
|
||||||
|
parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
|
||||||
|
parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
|
||||||
|
parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
|
||||||
|
parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
|
||||||
|
parser.add_argument("--seed", type=int, default=None, help="Random seed")
|
||||||
|
parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
|
||||||
|
parser.add_argument(
|
||||||
|
"--checkpoint-every",
|
||||||
|
type=int,
|
||||||
|
default=None,
|
||||||
|
help="Checkpoint interval (shards)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--no-resume",
|
||||||
|
action="store_true",
|
||||||
|
help="Do not resume from existing checkpoints",
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
from solver.datagen.dataset import (
|
||||||
|
DatasetConfig,
|
||||||
|
DatasetGenerator,
|
||||||
|
parse_simple_yaml,
|
||||||
|
)
|
||||||
|
|
||||||
|
config_dict: dict[str, object] = {}
|
||||||
|
if args.config:
|
||||||
|
config_dict = parse_simple_yaml(args.config) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# CLI args override config file (only when explicitly provided)
|
||||||
|
_override_map = {
|
||||||
|
"num_assemblies": args.num_assemblies,
|
||||||
|
"output_dir": args.output_dir,
|
||||||
|
"shard_size": args.shard_size,
|
||||||
|
"body_count_min": args.body_count_min,
|
||||||
|
"body_count_max": args.body_count_max,
|
||||||
|
"grounded_ratio": args.grounded_ratio,
|
||||||
|
"seed": args.seed,
|
||||||
|
"num_workers": args.num_workers,
|
||||||
|
"checkpoint_every": args.checkpoint_every,
|
||||||
|
}
|
||||||
|
for key, val in _override_map.items():
|
||||||
|
if val is not None:
|
||||||
|
config_dict[key] = val
|
||||||
|
|
||||||
|
if args.no_resume:
|
||||||
|
config_dict["resume"] = False
|
||||||
|
|
||||||
|
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||||
|
DatasetGenerator(config).run()
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
"""Entry point: try Hydra first, fall back to argparse."""
|
||||||
|
if not _try_hydra_main():
|
||||||
|
_argparse_main()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
0
solver/__init__.py
Normal file
0
solver/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Data generation utilities for assembly constraint training data."""
|
||||||
|
|
||||||
|
from solver.datagen.analysis import analyze_assembly
|
||||||
|
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||||
|
from solver.datagen.generator import (
|
||||||
|
COMPLEXITY_RANGES,
|
||||||
|
AxisStrategy,
|
||||||
|
SyntheticAssemblyGenerator,
|
||||||
|
)
|
||||||
|
from solver.datagen.jacobian import JacobianVerifier
|
||||||
|
from solver.datagen.labeling import AssemblyLabels, label_assembly
|
||||||
|
from solver.datagen.pebble_game import PebbleGame3D
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
PebbleState,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"COMPLEXITY_RANGES",
|
||||||
|
"AssemblyLabels",
|
||||||
|
"AxisStrategy",
|
||||||
|
"ConstraintAnalysis",
|
||||||
|
"DatasetConfig",
|
||||||
|
"DatasetGenerator",
|
||||||
|
"JacobianVerifier",
|
||||||
|
"Joint",
|
||||||
|
"JointType",
|
||||||
|
"PebbleGame3D",
|
||||||
|
"PebbleState",
|
||||||
|
"RigidBody",
|
||||||
|
"SyntheticAssemblyGenerator",
|
||||||
|
"analyze_assembly",
|
||||||
|
"label_assembly",
|
||||||
|
]
|
||||||
140
solver/datagen/analysis.py
Normal file
140
solver/datagen/analysis.py
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
"""Combined pebble game + Jacobian verification analysis.
|
||||||
|
|
||||||
|
Provides :func:`analyze_assembly`, the main entry point for full rigidity
|
||||||
|
analysis of an assembly using both combinatorial and numerical methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from solver.datagen.jacobian import JacobianVerifier
|
||||||
|
from solver.datagen.pebble_game import PebbleGame3D
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = ["analyze_assembly"]
|
||||||
|
|
||||||
|
_GROUND_ID = -1
|
||||||
|
|
||||||
|
|
||||||
|
def analyze_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> ConstraintAnalysis:
    """Run the full rigidity analysis (pebble game + Jacobian) of an assembly.

    Args:
        bodies: Rigid bodies making up the assembly.
        joints: Joints connecting those bodies.
        ground_body: If set, this body is fixed (adds 6 implicit constraints).

    Returns:
        ConstraintAnalysis with combinatorial and numerical results.
    """
    grounded = ground_body is not None

    # ---- Combinatorial analysis (pebble game) ----
    game = PebbleGame3D()
    edge_labels = []

    # Grounding is modeled as a FIXED joint between the grounded body and a
    # virtual ground body (id = -1): the pebble game then accounts for the
    # 6 removed DOF without breaking its invariants.
    if grounded:
        game.add_body(_GROUND_ID)
    for body in bodies:
        game.add_body(body.body_id)

    if grounded:
        anchor = bodies[0].position if bodies else np.zeros(3)
        # The ground joint's edges are infrastructure, not user constraints,
        # so they are deliberately kept out of edge_labels.
        game.add_joint(
            Joint(
                joint_id=-1,
                body_a=ground_body,
                body_b=_GROUND_ID,
                joint_type=JointType.FIXED,
                anchor_a=anchor,
                anchor_b=anchor,
            )
        )

    for joint in joints:
        edge_labels.extend(game.add_joint(joint))

    independent_count = len(game.state.independent_edges)

    # The virtual ground body contributes 6 pebbles to the raw total; strip
    # them from the user-facing DOF numbers.
    effective_dof = game.get_dof() - (6 if grounded else 0)
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))

    # Classify on the adjusted DOF, not the raw pebble-game output, because
    # the virtual ground body skews the raw numbers.
    redundant = game.get_redundant_count()
    if redundant > 0:
        classification = "mixed" if effective_internal_dof > 0 else "overconstrained"
    else:
        classification = (
            "underconstrained" if effective_internal_dof > 0 else "well-constrained"
        )

    # ---- Numerical analysis (constraint Jacobian) ----
    verifier = JacobianVerifier(bodies)
    for joint in joints:
        verifier.add_joint_constraints(joint)

    jac = verifier.get_jacobian()
    if grounded and jac.size > 0:
        # Fixing the ground body removes its 6 velocity columns.
        base = verifier.body_index[ground_body] * 6
        jac = np.delete(jac, list(range(base, base + 6)), axis=1)

    if jac.size > 0:
        singular_values = np.linalg.svd(jac, compute_uv=False)
        rank = int(np.sum(singular_values > 1e-8))
    else:
        rank = 0

    n_cols = jac.shape[1] if jac.size > 0 else 6 * len(bodies)
    nullity = n_cols - rank

    dependent = verifier.find_dependencies()

    # An ungrounded assembly keeps 6 trivial rigid-body motions.
    trivial_dof = 0 if grounded else 6
    numeric_internal_dof = nullity - trivial_dof

    # Constraints the pebble game accepted but the Jacobian rejects are
    # geometric degeneracies.
    degeneracies = max(0, independent_count - rank)

    # Numerically rigid iff the nullspace is only the trivial motions.
    rigid = nullity <= trivial_dof
    minimally_rigid = rigid and len(dependent) == 0

    return ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=redundant,
        combinatorial_classification=classification,
        per_edge_results=edge_labels,
        jacobian_rank=rank,
        jacobian_nullity=nullity,
        jacobian_internal_dof=max(0, numeric_internal_dof),
        numerically_dependent=dependent,
        geometric_degeneracies=degeneracies,
        is_rigid=rigid,
        is_minimally_rigid=minimally_rigid,
    )
|
||||||
624
solver/datagen/dataset.py
Normal file
624
solver/datagen/dataset.py
Normal file
@@ -0,0 +1,624 @@
|
|||||||
|
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
|
||||||
|
|
||||||
|
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
|
||||||
|
for parallel generation of synthetic assembly training data.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import math
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"DatasetConfig",
|
||||||
|
"DatasetGenerator",
|
||||||
|
"parse_simple_yaml",
|
||||||
|
]
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Configuration
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class DatasetConfig:
    """Configuration for synthetic dataset generation."""

    name: str = "synthetic"
    num_assemblies: int = 100_000
    output_dir: str = "data/synthetic"
    shard_size: int = 1000
    complexity_distribution: dict[str, float] = field(
        default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
    )
    body_count_min: int = 2
    body_count_max: int = 50
    templates: list[str] = field(
        default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"]
    )
    grounded_ratio: float = 0.5
    seed: int = 42
    num_workers: int = 4
    checkpoint_every: int = 5
    resume: bool = True

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
        """Construct from a parsed config dict (e.g. YAML or OmegaConf).

        Handles both flat keys (``body_count_min``) and nested forms
        (``body_count: {min: 2, max: 50}``).
        """
        # Keys copied straight through when present.
        passthrough = (
            "name",
            "num_assemblies",
            "output_dir",
            "shard_size",
            "grounded_ratio",
            "seed",
            "num_workers",
            "checkpoint_every",
            "resume",
        )
        kw: dict[str, Any] = {k: d[k] for k in passthrough if k in d}

        # body_count may be a nested {min, max} mapping or flat keys.
        bc = d.get("body_count")
        if isinstance(bc, dict):
            if "min" in bc:
                kw["body_count_min"] = int(bc["min"])
            if "max" in bc:
                kw["body_count_max"] = int(bc["max"])
        else:
            for flat_key in ("body_count_min", "body_count_max"):
                if flat_key in d:
                    kw[flat_key] = int(d[flat_key])

        cd = d.get("complexity_distribution")
        if isinstance(cd, dict):
            kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()}

        templates = d.get("templates")
        if isinstance(templates, list):
            kw["templates"] = [str(t) for t in templates]

        return cls(**kw)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shard specification / result
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ShardSpec:
    """Everything a worker needs to generate one shard deterministically."""

    shard_id: int  # position of this shard in the dataset
    start_example_id: int  # global id of the shard's first example
    count: int  # number of examples to generate in this shard
    seed: int  # per-shard RNG seed (derived from the global seed)
    complexity_distribution: dict[str, float]  # tier name -> sampling weight
    body_count_min: int
    body_count_max: int
    grounded_ratio: float  # passed through to the assembly generator
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ShardResult:
    """Summary a shard worker sends back to the orchestrator."""

    shard_id: int  # which shard this result describes
    num_examples: int  # examples actually written (failed examples are skipped)
    file_path: str  # where the shard file was saved
    generation_time_s: float  # wall-clock generation time for this shard
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Seed derivation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
    """Deterministically derive a 32-bit per-shard seed from the global seed.

    Hashing the (global_seed, shard_id) pair decorrelates shard RNG streams
    while keeping generation fully reproducible for a fixed global seed.
    """
    key = f"{global_seed}:{shard_id}".encode()
    # First 4 bytes of the SHA-256 digest as an unsigned big-endian int
    # (identical to int(hexdigest()[:8], 16)).
    return int.from_bytes(hashlib.sha256(key).digest()[:4], "big")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Progress display
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class _PrintProgress:
    """Minimal stdout progress display used when tqdm is not installed."""

    def __init__(self, total: int) -> None:
        self.total = total  # number of units expected
        self.current = 0  # units completed so far
        self.start_time = time.monotonic()

    def update(self, n: int = 1) -> None:
        """Advance the counter by *n* and redraw the status line in place."""
        self.current += n
        elapsed = time.monotonic() - self.start_time
        rate = self.current / elapsed if elapsed > 0 else 0.0
        eta = (self.total - self.current) / rate if rate > 0 else 0.0
        percent = 100.0 * self.current / self.total
        status = (
            f"\r[{percent:5.1f}%] {self.current}/{self.total} shards"
            f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
        )
        sys.stdout.write(status)
        sys.stdout.flush()

    def close(self) -> None:
        """Terminate the status line with a newline."""
        sys.stdout.write("\n")
        sys.stdout.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_progress(total: int) -> _PrintProgress:
    """Create a progress tracker: tqdm when installed, else a plain printer."""
    try:
        from tqdm import tqdm  # type: ignore[import-untyped]
    except ImportError:
        # tqdm is optional; fall back to the simple stdout display.
        return _PrintProgress(total)
    return tqdm(total=total, desc="Generating shards", unit="shard")  # type: ignore[no-any-return,return-value]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shard I/O
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _save_shard(
    shard_id: int,
    examples: list[dict[str, Any]],
    shards_dir: Path,
) -> Path:
    """Persist one shard to disk, preferring torch's .pt format over JSON.

    Falls back to JSON when torch is not importable so generation still
    works in torch-free environments. Returns the written file path.
    """
    shards_dir.mkdir(parents=True, exist_ok=True)
    try:
        import torch  # type: ignore[import-untyped]
    except ImportError:
        target = shards_dir / f"shard_{shard_id:05d}.json"
        with open(target, "w") as fh:
            json.dump(examples, fh)
    else:
        target = shards_dir / f"shard_{shard_id:05d}.pt"
        torch.save(examples, target)
    return target
|
||||||
|
|
||||||
|
|
||||||
|
def _load_shard(path: Path) -> list[dict[str, Any]]:
    """Read one shard back from disk, dispatching on the file extension."""
    if path.suffix != ".pt":
        with open(path) as fh:
            data: list[dict[str, Any]] = json.load(fh)
        return data

    import torch  # type: ignore[import-untyped]

    # weights_only=False: shards hold arbitrary Python objects, not tensors.
    loaded: list[dict[str, Any]] = torch.load(path, weights_only=False)
    return loaded
|
||||||
|
|
||||||
|
|
||||||
|
def _shard_format() -> str:
    """Return the shard file extension matching what ``_save_shard`` writes."""
    try:
        import torch  # type: ignore[import-untyped]  # noqa: F401
    except ImportError:
        return ".json"
    return ".pt"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shard worker (module-level for pickling)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
    """Generate a single shard — top-level function for ProcessPoolExecutor.

    Failed examples are logged and skipped, so a shard may hold fewer than
    ``spec.count`` entries.
    """
    from solver.datagen.generator import SyntheticAssemblyGenerator

    start = time.monotonic()
    generator = SyntheticAssemblyGenerator(seed=spec.seed)
    # Separate RNG stream (seed + 1) for tier selection so it does not
    # interact with the assembly generator's own randomness.
    tier_rng = np.random.default_rng(spec.seed + 1)

    tiers = list(spec.complexity_distribution.keys())
    raw_weights = [spec.complexity_distribution[t] for t in tiers]
    weight_sum = sum(raw_weights)
    probs = [w / weight_sum for w in raw_weights]  # normalized for rng.choice

    examples: list[dict[str, Any]] = []
    for offset in range(spec.count):
        tier = tiers[int(tier_rng.choice(len(tiers), p=probs))]
        try:
            batch = generator.generate_training_batch(
                batch_size=1,
                complexity_tier=tier,  # type: ignore[arg-type]
                grounded_ratio=spec.grounded_ratio,
            )
            example = batch[0]
            example["example_id"] = spec.start_example_id + offset
            example["complexity_tier"] = tier
            examples.append(example)
        except Exception:
            logger.warning(
                "Shard %d, example %d failed — skipping",
                spec.shard_id,
                offset,
                exc_info=True,
            )

    shard_path = _save_shard(spec.shard_id, examples, Path(output_dir) / "shards")

    return ShardResult(
        shard_id=spec.shard_id,
        num_examples=len(examples),
        file_path=str(shard_path),
        generation_time_s=time.monotonic() - start,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Minimal YAML parser
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_scalar(value: str) -> int | float | bool | str:
    """Parse a YAML scalar into bool, int, float, or (unquoted) string.

    An inline comment (`` #...``) is stripped first. Booleans accept the
    true/yes and false/no spellings, case-insensitively.
    """
    # Strip inline comments (space + #). The previous version had a second
    # `elif` with the *identical* condition — unreachable dead code, removed.
    if " #" in value:
        value = value[: value.index(" #")].strip()
    v = value.strip()
    lowered = v.lower()
    if lowered in ("true", "yes"):
        return True
    if lowered in ("false", "no"):
        return False
    # Numeric parsing: int first (so "5" stays an int), then float.
    for cast in (int, float):
        try:
            return cast(v)
        except ValueError:
            pass
    return v.strip("'\"")


def parse_simple_yaml(path: str) -> dict[str, Any]:
    """Parse a simple YAML file (flat scalars, one-level dicts, lists).

    This is **not** a full YAML parser. It handles the structure of
    ``configs/dataset/synthetic.yaml``: top-level ``key: value`` pairs,
    one-level nested mappings, and ``- item`` lists.

    Args:
        path: Path of the YAML file to read.

    Returns:
        Mapping of top-level keys to scalars, dicts, or lists.
    """
    result: dict[str, Any] = {}
    current_key: str | None = None  # top-level key whose nested block we are in

    with open(path) as f:
        for raw_line in f:
            line = raw_line.rstrip()

            # Skip blank lines and full-line comments.
            if not line or line.lstrip().startswith("#"):
                continue

            indent = len(line) - len(line.lstrip())

            # Top-level "key: value", or "key:" opening a nested block.
            if indent == 0 and ":" in line:
                key, _, value = line.partition(":")
                key = key.strip()
                value = value.strip()
                if value:
                    result[key] = _parse_scalar(value)
                    current_key = None
                else:
                    current_key = key
                    # Placeholder dict; converted to a list if "- item"
                    # lines follow instead of nested "k: v" pairs.
                    result[key] = {}
                continue

            # List item under the current block.
            if indent > 0 and line.lstrip().startswith("- "):
                item = line.lstrip()[2:].strip()
                if current_key is not None:
                    if isinstance(result.get(current_key), dict) and not result[current_key]:
                        result[current_key] = []
                    if isinstance(result.get(current_key), list):
                        result[current_key].append(_parse_scalar(item))
                continue

            # Nested "key: value" inside the current block.
            if indent > 0 and ":" in line and current_key is not None:
                k, _, v = line.partition(":")
                k = k.strip()
                v = v.strip()
                if v:
                    # Strip inline comments before scalar parsing.
                    if " #" in v:
                        v = v[: v.index(" #")].strip()
                    if not isinstance(result.get(current_key), dict):
                        result[current_key] = {}
                    result[current_key][k] = _parse_scalar(v)
                continue

    return result
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Dataset generator orchestrator
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetGenerator:
    """Orchestrates parallel dataset generation with sharding and checkpointing.

    Completed shard files on disk double as the resume state: an interrupted
    run restarted with ``resume=True`` regenerates only the missing shards.
    """

    def __init__(self, config: DatasetConfig) -> None:
        self.config = config
        self.output_path = Path(config.output_dir)
        self.shards_dir = self.output_path / "shards"
        # Transient progress file; removed once generation completes.
        self.checkpoint_file = self.output_path / ".checkpoint.json"
        self.index_file = self.output_path / "index.json"
        self.stats_file = self.output_path / "stats.json"

    # -- public API --

    def run(self) -> None:
        """Generate the full dataset: shards, index.json, and stats.json."""
        self.output_path.mkdir(parents=True, exist_ok=True)
        self.shards_dir.mkdir(parents=True, exist_ok=True)

        shards = self._plan_shards()
        total_shards = len(shards)

        # Resume: treat existing non-empty shard files as done.
        completed: set[int] = set()
        if self.config.resume:
            completed = self._find_completed_shards()
        pending = [s for s in shards if s.shard_id not in completed]

        if not pending:
            logger.info("All %d shards already complete.", total_shards)
        else:
            logger.info(
                "Generating %d shards (%d already complete).",
                len(pending),
                len(completed),
            )

        self._run_workers(pending, completed, total_shards)

        # Finalize: index, statistics, human-readable summary.
        self._build_index()
        stats = self._compute_statistics()
        self._write_statistics(stats)
        self._print_summary(stats)

        # Generation is complete — the checkpoint is no longer needed.
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    # -- internal helpers --

    def _run_workers(
        self,
        pending: list[ShardSpec],
        completed: set[int],
        total_shards: int,
    ) -> None:
        """Fan pending shard specs out to a process pool, updating *completed*."""
        progress = _make_progress(len(pending))
        workers = max(1, self.config.num_workers)
        checkpoint_counter = 0

        with ProcessPoolExecutor(max_workers=workers) as pool:
            futures = {
                pool.submit(
                    _generate_shard_worker,
                    spec,
                    str(self.output_path),
                ): spec.shard_id
                for spec in pending
            }
            for future in as_completed(futures):
                shard_id = futures[future]
                try:
                    result = future.result()
                    completed.add(result.shard_id)
                    logger.debug(
                        "Shard %d: %d examples in %.1fs",
                        result.shard_id,
                        result.num_examples,
                        result.generation_time_s,
                    )
                except Exception:
                    # A failed shard is logged but does not abort the run;
                    # it will be retried on the next resume.
                    logger.error("Shard %d failed", shard_id, exc_info=True)
                progress.update(1)
                checkpoint_counter += 1
                # Persist progress periodically so a crash resumes cheaply.
                if checkpoint_counter >= self.config.checkpoint_every:
                    self._update_checkpoint(completed, total_shards)
                    checkpoint_counter = 0

        progress.close()

    def _plan_shards(self) -> list[ShardSpec]:
        """Divide num_assemblies into shards of at most shard_size each."""
        n = self.config.num_assemblies
        size = self.config.shard_size
        return [
            ShardSpec(
                shard_id=i,
                start_example_id=i * size,
                count=min(size, n - i * size),  # last shard may be short
                seed=_derive_shard_seed(self.config.seed, i),
                complexity_distribution=dict(self.config.complexity_distribution),
                body_count_min=self.config.body_count_min,
                body_count_max=self.config.body_count_max,
                grounded_ratio=self.config.grounded_ratio,
            )
            for i in range(math.ceil(n / size))
        ]

    def _iter_shard_files(self) -> Iterator[tuple[int, Path]]:
        """Yield ``(shard_id, path)`` for valid shard files, sorted by name.

        Shared by index building, statistics, and resume scanning so all
        three agree on which files count as shards (previously the stats
        pass skipped the shard-id validation the index pass performed).
        """
        if not self.shards_dir.exists():
            return
        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            try:
                yield int(p.stem.split("_")[1]), p
            except (ValueError, IndexError):
                continue

    def _find_completed_shards(self) -> set[int]:
        """Return ids of shards that already exist on disk.

        Empty files are treated as incomplete (e.g. a crash mid-write).
        """
        return {sid for sid, p in self._iter_shard_files() if p.stat().st_size > 0}

    def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
        """Write the checkpoint file recording which shards are done."""
        data = {
            "completed_shards": sorted(completed),
            "total_shards": total_shards,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f)

    def _build_index(self) -> None:
        """Build index.json mapping shard files to assembly ID ranges."""
        shards_info: dict[str, dict[str, int]] = {}
        total_assemblies = 0

        for shard_id, p in self._iter_shard_files():
            count = len(_load_shard(p))
            shards_info[p.name] = {
                "start_id": shard_id * self.config.shard_size,
                "count": count,
            }
            total_assemblies += count

        index = {
            "format_version": 1,
            "total_assemblies": total_assemblies,
            "total_shards": len(shards_info),
            "shard_format": _shard_format().lstrip("."),
            "shards": shards_info,
        }
        with open(self.index_file, "w") as f:
            json.dump(index, f, indent=2)

    def _compute_statistics(self) -> dict[str, Any]:
        """Aggregate label statistics across all shard files."""
        classification_counts: dict[str, int] = {}
        body_count_hist: dict[int, int] = {}
        joint_type_counts: dict[str, int] = {}
        dof_values: list[int] = []
        degeneracy_values: list[int] = []
        rigid_count = 0
        minimally_rigid_count = 0
        total = 0

        for _shard_id, p in self._iter_shard_files():
            for ex in _load_shard(p):
                total += 1
                cls = str(ex.get("assembly_classification", "unknown"))
                classification_counts[cls] = classification_counts.get(cls, 0) + 1
                nb = int(ex.get("n_bodies", 0))
                body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
                for j in ex.get("joints", []):
                    jt = str(j.get("type", "unknown"))
                    joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
                dof_values.append(int(ex.get("internal_dof", 0)))
                degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
                if ex.get("is_rigid"):
                    rigid_count += 1
                if ex.get("is_minimally_rigid"):
                    minimally_rigid_count += 1

        # Fall back to a single zero so summary stats stay well-defined
        # even when no examples were generated.
        dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
        deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)

        return {
            "total_examples": total,
            "classification_distribution": dict(sorted(classification_counts.items())),
            "body_count_histogram": dict(sorted(body_count_hist.items())),
            "joint_type_distribution": dict(sorted(joint_type_counts.items())),
            "dof_statistics": {
                "mean": float(dof_arr.mean()),
                "std": float(dof_arr.std()),
                "min": int(dof_arr.min()),
                "max": int(dof_arr.max()),
                "median": float(np.median(dof_arr)),
            },
            "geometric_degeneracy": {
                "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
                "fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
                "mean_degeneracies": float(deg_arr.mean()),
            },
            "rigidity": {
                "rigid_count": rigid_count,
                "rigid_fraction": (rigid_count / total if total > 0 else 0.0),
                "minimally_rigid_count": minimally_rigid_count,
                "minimally_rigid_fraction": (
                    minimally_rigid_count / total if total > 0 else 0.0
                ),
            },
        }

    def _write_statistics(self, stats: dict[str, Any]) -> None:
        """Write stats.json."""
        with open(self.stats_file, "w") as f:
            json.dump(stats, f, indent=2)

    def _print_summary(self, stats: dict[str, Any]) -> None:
        """Print a human-readable summary to stdout."""
        print("\n=== Dataset Generation Summary ===")
        print(f"Total examples: {stats['total_examples']}")
        print(f"Output directory: {self.output_path}")
        print()
        print("Classification distribution:")
        for cls, count in stats["classification_distribution"].items():
            frac = count / max(stats["total_examples"], 1) * 100
            print(f"  {cls}: {count} ({frac:.1f}%)")
        print()
        print("Joint type distribution:")
        for jt, count in stats["joint_type_distribution"].items():
            print(f"  {jt}: {count}")
        print()
        dof = stats["dof_statistics"]
        print(
            f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
        )
        rig = stats["rigidity"]
        print(
            f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
            f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
            f"{rig['minimally_rigid_count']} minimally rigid"
        )
        deg = stats["geometric_degeneracy"]
        print(
            f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
            f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
        )
|
||||||
893
solver/datagen/generator.py
Normal file
893
solver/datagen/generator.py
Normal file
@@ -0,0 +1,893 @@
|
|||||||
|
"""Synthetic assembly graph generator for training data production.
|
||||||
|
|
||||||
|
Generates assembly graphs with known constraint classifications using
|
||||||
|
the pebble game and Jacobian verification. Each assembly is fully labeled
|
||||||
|
with per-constraint independence flags and assembly-level classification.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING, Literal
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.spatial.transform import Rotation
|
||||||
|
|
||||||
|
from solver.datagen.analysis import analyze_assembly
|
||||||
|
from solver.datagen.labeling import label_assembly
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"COMPLEXITY_RANGES",
|
||||||
|
"AxisStrategy",
|
||||||
|
"ComplexityTier",
|
||||||
|
"SyntheticAssemblyGenerator",
|
||||||
|
]
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
ComplexityTier = Literal["simple", "medium", "complex"]
|
||||||
|
|
||||||
|
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
|
||||||
|
"simple": (2, 6),
|
||||||
|
"medium": (6, 16),
|
||||||
|
"complex": (16, 51),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Axis sampling strategies
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
|
||||||
|
|
||||||
|
|
||||||
|
class SyntheticAssemblyGenerator:
|
||||||
|
"""Generates assembly graphs with known minimal constraint sets.
|
||||||
|
|
||||||
|
Uses the pebble game to incrementally build assemblies, tracking
|
||||||
|
exactly which constraints are independent at each step. This produces
|
||||||
|
labeled training data: (assembly_graph, constraint_set, labels).
|
||||||
|
|
||||||
|
Labels per constraint:
|
||||||
|
- independent: bool (does this constraint remove a DOF?)
|
||||||
|
- redundant: bool (is this constraint overconstrained?)
|
||||||
|
- minimal_set: bool (part of a minimal rigidity basis?)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, seed: int = 42) -> None:
|
||||||
|
self.rng = np.random.default_rng(seed)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Private helpers
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _random_position(self, scale: float = 5.0) -> np.ndarray:
|
||||||
|
"""Generate random 3D position within [-scale, scale] cube."""
|
||||||
|
return self.rng.uniform(-scale, scale, size=3)
|
||||||
|
|
||||||
|
def _random_axis(self) -> np.ndarray:
    """Return a unit vector drawn uniformly from directions on the sphere."""
    raw = self.rng.standard_normal(3)
    return raw / np.linalg.norm(raw)
|
def _random_orientation(self) -> np.ndarray:
    """Return a uniformly random 3x3 rotation matrix."""
    rotation = Rotation.random(random_state=self.rng)
    matrix: np.ndarray = rotation.as_matrix()
    return matrix
|
def _cardinal_axis(self) -> np.ndarray:
    """Pick uniformly from the six signed cardinal directions."""
    # +/-e_x, +/-e_y, +/-e_z enumerated explicitly.
    choices = np.array(
        [
            [1, 0, 0],
            [-1, 0, 0],
            [0, 1, 0],
            [0, -1, 0],
            [0, 0, 1],
            [0, 0, -1],
        ],
        dtype=float,
    )
    pick = int(self.rng.integers(6))
    selected: np.ndarray = choices[pick]
    return selected
|
def _near_parallel_axis(
    self,
    base_axis: np.ndarray,
    perturbation_scale: float = 0.05,
) -> np.ndarray:
    """Return *base_axis* nudged by small Gaussian noise, re-normalized."""
    noise = self.rng.standard_normal(3) * perturbation_scale
    candidate = base_axis + noise
    return candidate / np.linalg.norm(candidate)
|
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
    """Sample a joint axis using the specified strategy."""
    if strategy == "near_parallel":
        # Perturb around the +Z reference direction.
        return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
    if strategy == "cardinal":
        return self._cardinal_axis()
    # "random" (and anything unrecognized) falls back to a uniform axis.
    return self._random_axis()
|
def _resolve_axis(
    self,
    strategy: AxisStrategy,
    parallel_axis_prob: float,
    shared_axis: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Return (axis_for_this_joint, shared_axis_to_propagate).

    Once a shared axis exists, every later joint receives a near-parallel
    copy of it. Otherwise parallel injection may trigger (with probability
    *parallel_axis_prob*), choosing a base axis that becomes the shared
    axis for subsequent calls.
    """
    if shared_axis is not None:
        # Propagate: this joint is near-parallel to the established axis.
        return self._near_parallel_axis(shared_axis), shared_axis
    # NOTE: rng.random() is consumed only when the probability is positive,
    # keeping the RNG stream identical to prob == 0 runs.
    inject = parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob
    if inject:
        base = self._sample_axis(strategy)
        return base.copy(), base
    return self._sample_axis(strategy), None
|
def _select_joint_type(
    self,
    joint_types: JointType | list[JointType],
) -> JointType:
    """Select a joint type from a single type or a list of candidates."""
    if not isinstance(joint_types, list):
        # A single type is used as-is, no RNG draw.
        return joint_types
    pick = int(self.rng.integers(0, len(joint_types)))
    return joint_types[pick]
|
def _create_joint(
    self,
    joint_id: int,
    body_a_id: int,
    body_b_id: int,
    pos_a: np.ndarray,
    pos_b: np.ndarray,
    joint_type: JointType,
    *,
    axis: np.ndarray | None = None,
    orient_a: np.ndarray | None = None,
    orient_b: np.ndarray | None = None,
) -> Joint:
    """Create a joint between two bodies.

    When both orientations are provided, each anchor is offset from its
    body's center along a random local direction rotated into the world
    frame; otherwise both anchors are placed at the midpoint.
    """
    have_orientations = orient_a is not None and orient_b is not None
    if have_orientations:
        # Offset magnitude scales with body separation (floored at 0.1
        # so coincident bodies still get a nonzero spread).
        separation = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
        spread = separation * 0.2
        local_a = self.rng.standard_normal(3) * spread
        local_b = self.rng.standard_normal(3) * spread
        anchor_a = pos_a + orient_a @ local_a
        anchor_b = pos_b + orient_b @ local_b
    else:
        midpoint = (pos_a + pos_b) / 2.0
        anchor_a = midpoint
        anchor_b = midpoint

    if axis is None:
        axis = self._random_axis()
    return Joint(
        joint_id=joint_id,
        body_a=body_a_id,
        body_b=body_b_id,
        joint_type=joint_type,
        anchor_a=anchor_a,
        anchor_b=anchor_b,
        axis=axis,
    )
|
# ------------------------------------------------------------------
|
||||||
|
# Original generators (chain / rigid / overconstrained)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_chain_assembly(
    self,
    n_bodies: int,
    joint_type: JointType = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a serial kinematic chain.

    Each body connects to the next with the specified joint type. A
    serial chain is never rigid without closing loops, so the result
    is always underconstrained.
    """
    # Bodies laid out along +X, 2 units apart, each with a random
    # orientation (comprehension preserves the RNG draw order).
    bodies = [
        RigidBody(
            body_id=k,
            position=np.array([k * 2.0, 0.0, 0.0]),
            orientation=self._random_orientation(),
        )
        for k in range(n_bodies)
    ]

    joints = []
    shared_axis: np.ndarray | None = None
    for k in range(n_bodies - 1):
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                k,
                k,
                k + 1,
                bodies[k].position,
                bodies[k + 1].position,
                joint_type,
                axis=axis,
                orient_a=bodies[k].orientation,
                orient_b=bodies[k + 1].orientation,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
def generate_rigid_assembly(
    self,
    n_bodies: int,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a minimally rigid assembly by adding joints until rigid.

    Strategy: start with fixed joints on a spanning tree (guarantees
    rigidity), then randomly relax some to weaker joint types while
    maintaining rigidity via the pebble game check.

    Args:
        n_bodies: Number of bodies to generate.
        grounded: Whether body 0 is treated as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of parallel axis injection.
    """
    bodies = []
    for i in range(n_bodies):
        bodies.append(
            RigidBody(
                body_id=i,
                position=self._random_position(),
                orientation=self._random_orientation(),
            )
        )

    # Build spanning tree with fixed joints (overconstrained)
    joints: list[Joint] = []
    shared_axis: np.ndarray | None = None
    for i in range(1, n_bodies):
        # Each new body attaches to a uniformly chosen earlier body,
        # which guarantees a connected tree.
        parent = int(self.rng.integers(0, i))
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                i - 1,
                parent,
                i,
                bodies[parent].position,
                bodies[i].position,
                JointType.FIXED,
                axis=axis,
                orient_a=bodies[parent].orientation,
                orient_b=bodies[i].orientation,
            )
        )

    # Try relaxing joints to weaker types while maintaining rigidity
    weaker_types = [
        JointType.REVOLUTE,
        JointType.CYLINDRICAL,
        JointType.BALL,
    ]

    ground = 0 if grounded else None
    # Visit joints in a random order; for each, keep the first candidate
    # type (tried in weaker_types order) that preserves rigidity.
    for idx in self.rng.permutation(len(joints)):
        original_type = joints[idx].joint_type
        for candidate in weaker_types:
            joints[idx].joint_type = candidate
            analysis = analyze_assembly(bodies, joints, ground_body=ground)
            if analysis.is_rigid:
                break  # Keep the weaker type
        else:
            # No candidate kept the assembly rigid: restore the
            # original (FIXED) type for this joint.
            joints[idx].joint_type = original_type

    # Re-analyze once more: the loop above may have ended on a restore,
    # leaving `analysis` stale relative to the final joint types.
    analysis = analyze_assembly(bodies, joints, ground_body=ground)
    return bodies, joints, analysis
|
def generate_overconstrained_assembly(
    self,
    n_bodies: int,
    extra_joints: int = 2,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate an assembly with known redundant constraints.

    Starts from a rigid assembly, then adds extra joints that the
    pebble game will flag as redundant.
    """
    bodies, joints, _ = self.generate_rigid_assembly(
        n_bodies,
        grounded=grounded,
        axis_strategy=axis_strategy,
        parallel_axis_prob=parallel_axis_prob,
    )

    # Candidate types for the redundant joints (loop-invariant).
    candidate_types = [
        JointType.REVOLUTE,
        JointType.FIXED,
        JointType.BALL,
    ]
    next_joint_id = len(joints)
    shared_axis: np.ndarray | None = None
    for _ in range(extra_joints):
        # Two distinct bodies for the redundant connection.
        a, b = self.rng.choice(n_bodies, size=2, replace=False)
        jtype = candidate_types[int(self.rng.integers(len(candidate_types)))]
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                next_joint_id,
                int(a),
                int(b),
                bodies[int(a)].position,
                bodies[int(b)].position,
                jtype,
                axis=axis,
                orient_a=bodies[int(a)].orientation,
                orient_b=bodies[int(b)].orientation,
            )
        )
        next_joint_id += 1

    ground = 0 if grounded else None
    analysis = analyze_assembly(bodies, joints, ground_body=ground)
    return bodies, joints, analysis
|
# ------------------------------------------------------------------
|
||||||
|
# New topology generators
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_tree_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    branching_factor: int = 3,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a random tree topology with configurable branching.

    Creates a tree where each body has at most *branching_factor*
    children in total — enforced cumulatively, so a parent that is
    selected for expansion multiple times cannot exceed the limit
    (previously only the children spawned in a single round were
    counted). Different branches can use different joint types if a
    list is provided. Always underconstrained (no closed loops).

    Args:
        n_bodies: Total bodies (root + children).
        joint_types: Single type or list to sample from per joint.
        branching_factor: Max children per parent (>= 1).

    Raises:
        ValueError: If *branching_factor* < 1.
    """
    if branching_factor < 1:
        msg = "branching_factor must be >= 1"
        raise ValueError(msg)

    bodies: list[RigidBody] = [
        RigidBody(
            body_id=0,
            position=np.zeros(3),
            orientation=self._random_orientation(),
        )
    ]
    joints: list[Joint] = []

    available_parents = [0]
    # Cumulative child count per body, so re-picked parents cannot
    # exceed branching_factor across rounds.
    child_counts: dict[int, int] = {0: 0}
    next_id = 1
    joint_id = 0
    shared_axis: np.ndarray | None = None

    while next_id < n_bodies and available_parents:
        pidx = int(self.rng.integers(0, len(available_parents)))
        parent_id = available_parents[pidx]
        parent_pos = bodies[parent_id].position

        # Spawn at most the parent's remaining capacity, and never more
        # bodies than are still needed. Parents in available_parents
        # always have capacity >= 1 (retired otherwise).
        remaining_capacity = branching_factor - child_counts[parent_id]
        max_children = min(remaining_capacity, n_bodies - next_id)
        n_children = int(self.rng.integers(1, max_children + 1))

        for _ in range(n_children):
            direction = self._random_axis()
            distance = self.rng.uniform(1.5, 3.0)
            child_pos = parent_pos + direction * distance
            child_orient = self._random_orientation()

            bodies.append(
                RigidBody(
                    body_id=next_id,
                    position=child_pos,
                    orientation=child_orient,
                )
            )
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    parent_id,
                    next_id,
                    parent_pos,
                    child_pos,
                    jtype,
                    axis=axis,
                    orient_a=bodies[parent_id].orientation,
                    orient_b=child_orient,
                )
            )

            child_counts[parent_id] += 1
            child_counts[next_id] = 0
            available_parents.append(next_id)
            next_id += 1
            joint_id += 1
            if next_id >= n_bodies:
                break

        # Retire the parent once its cumulative limit is reached, or
        # randomly (p=0.3) to vary tree shape.
        if child_counts[parent_id] >= branching_factor or self.rng.random() < 0.3:
            available_parents.pop(pidx)

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
def generate_loop_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a single closed loop (ring) of bodies.

    The loop-closing joint introduces redundancy, which makes this
    topology useful for overconstrained training data.

    Args:
        n_bodies: Bodies in the ring (>= 3).
        joint_types: Single type or list to sample from per joint.

    Raises:
        ValueError: If *n_bodies* < 3.
    """
    if n_bodies < 3:
        msg = "Loop assembly requires at least 3 bodies"
        raise ValueError(msg)

    bodies: list[RigidBody] = []
    joints: list[Joint] = []

    # Bodies on a jittered circle in the XY plane with small Z noise.
    base_radius = max(2.0, n_bodies * 0.4)
    for idx in range(n_bodies):
        angle = 2 * np.pi * idx / n_bodies
        radius = base_radius + self.rng.uniform(-0.5, 0.5)
        position = np.array(
            [
                radius * np.cos(angle),
                radius * np.sin(angle),
                float(self.rng.uniform(-0.2, 0.2)),
            ]
        )
        bodies.append(
            RigidBody(
                body_id=idx,
                position=position,
                orientation=self._random_orientation(),
            )
        )

    shared_axis: np.ndarray | None = None
    for idx in range(n_bodies):
        # Wrap-around neighbor closes the ring on the last iteration.
        nxt = (idx + 1) % n_bodies
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                idx,
                idx,
                nxt,
                bodies[idx].position,
                bodies[nxt].position,
                jtype,
                axis=axis,
                orient_a=bodies[idx].orientation,
                orient_b=bodies[nxt].orientation,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
def generate_star_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a star topology with a central hub and satellites.

    Body 0 is the hub; every other body connects directly to it.
    Underconstrained because satellites never connect to each other.

    Args:
        n_bodies: Total bodies including hub (>= 2).
        joint_types: Single type or list to sample from per joint.

    Raises:
        ValueError: If *n_bodies* < 2.
    """
    if n_bodies < 2:
        msg = "Star assembly requires at least 2 bodies"
        raise ValueError(msg)

    hub_orient = self._random_orientation()
    hub = RigidBody(body_id=0, position=np.zeros(3), orientation=hub_orient)
    bodies: list[RigidBody] = [hub]
    joints: list[Joint] = []

    shared_axis: np.ndarray | None = None
    for sat_id in range(1, n_bodies):
        # Satellite placed along a random direction from the origin.
        direction = self._random_axis()
        distance = self.rng.uniform(2.0, 5.0)
        pos = direction * distance
        sat_orient = self._random_orientation()
        bodies.append(RigidBody(body_id=sat_id, position=pos, orientation=sat_orient))

        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                sat_id - 1,
                0,
                sat_id,
                np.zeros(3),
                pos,
                jtype,
                axis=axis,
                orient_a=hub_orient,
                orient_b=sat_orient,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
def generate_mixed_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    edge_density: float = 0.3,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a mixed topology combining tree and loop elements.

    Builds a spanning tree for connectivity, then adds extra edges
    based on *edge_density* to create loops and redundancy.

    Args:
        n_bodies: Number of bodies.
        joint_types: Single type or list to sample from per joint.
        edge_density: Fraction of non-tree edges to add (0.0-1.0).

    Raises:
        ValueError: If *edge_density* not in [0.0, 1.0].
    """
    if not 0.0 <= edge_density <= 1.0:
        msg = "edge_density must be in [0.0, 1.0]"
        raise ValueError(msg)

    bodies: list[RigidBody] = []
    joints: list[Joint] = []

    for i in range(n_bodies):
        bodies.append(
            RigidBody(
                body_id=i,
                position=self._random_position(),
                orientation=self._random_orientation(),
            )
        )

    # Phase 1: spanning tree
    joint_id = 0
    # Undirected edges recorded as frozensets so {a,b} == {b,a}.
    existing_edges: set[frozenset[int]] = set()
    shared_axis: np.ndarray | None = None
    for i in range(1, n_bodies):
        # Attach each body to a uniformly chosen earlier one (connected tree).
        parent = int(self.rng.integers(0, i))
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                joint_id,
                parent,
                i,
                bodies[parent].position,
                bodies[i].position,
                jtype,
                axis=axis,
                orient_a=bodies[parent].orientation,
                orient_b=bodies[i].orientation,
            )
        )
        existing_edges.add(frozenset([parent, i]))
        joint_id += 1

    # Phase 2: extra edges based on density
    # All unordered pairs not already connected by the tree.
    candidates: list[tuple[int, int]] = []
    for i in range(n_bodies):
        for j in range(i + 1, n_bodies):
            if frozenset([i, j]) not in existing_edges:
                candidates.append((i, j))

    # Number of loop-closing edges is a fixed fraction of the candidates.
    n_extra = int(edge_density * len(candidates))
    self.rng.shuffle(candidates)

    for a, b in candidates[:n_extra]:
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                joint_id,
                a,
                b,
                bodies[a].position,
                bodies[b].position,
                jtype,
                axis=axis,
                orient_a=bodies[a].orientation,
                orient_b=bodies[b].orientation,
            )
        )
        joint_id += 1

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
# ------------------------------------------------------------------
|
||||||
|
# Batch generation
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def generate_training_batch(
    self,
    batch_size: int = 100,
    n_bodies_range: tuple[int, int] | None = None,
    complexity_tier: ComplexityTier | None = None,
    *,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
    grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
    """Generate a batch of labeled training examples.

    Each example contains body positions, joint descriptions,
    per-joint independence labels, and assembly-level classification.

    Args:
        batch_size: Number of assemblies to generate.
        n_bodies_range: ``(min, max_exclusive)`` body count.
            Overridden by *complexity_tier* when both are given.
        complexity_tier: Predefined range (``"simple"`` / ``"medium"``
            / ``"complex"``). Overrides *n_bodies_range*.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of parallel axis injection.
        grounded_ratio: Fraction of examples that are grounded.
    """
    # Resolve the body-count range: tier wins, then explicit range,
    # then the default (3, 8).
    if complexity_tier is not None:
        n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
    elif n_bodies_range is None:
        n_bodies_range = (3, 8)

    # Joint-type pool passed to the multi-type generators below.
    _joint_pool = [
        JointType.REVOLUTE,
        JointType.BALL,
        JointType.CYLINDRICAL,
        JointType.FIXED,
    ]

    # Common geometry kwargs forwarded to every generator call.
    geo_kw: dict[str, Any] = {
        "axis_strategy": axis_strategy,
        "parallel_axis_prob": parallel_axis_prob,
    }

    examples: list[dict[str, Any]] = []

    for i in range(batch_size):
        # Upper bound is exclusive, matching COMPLEXITY_RANGES.
        n = int(self.rng.integers(*n_bodies_range))
        # Uniformly pick one of the 7 topology generators (0..6).
        gen_idx = int(self.rng.integers(7))
        grounded = bool(self.rng.random() < grounded_ratio)

        if gen_idx == 0:
            # Chains exclude FIXED (a fixed chain would be trivially rigid).
            _chain_types = [
                JointType.REVOLUTE,
                JointType.BALL,
                JointType.CYLINDRICAL,
            ]
            jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
            bodies, joints, analysis = self.generate_chain_assembly(
                n,
                jtype,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "chain"
        elif gen_idx == 1:
            bodies, joints, analysis = self.generate_rigid_assembly(
                n,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "rigid"
        elif gen_idx == 2:
            # 1-3 redundant joints on top of a rigid base.
            extra = int(self.rng.integers(1, 4))
            bodies, joints, analysis = self.generate_overconstrained_assembly(
                n,
                extra,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "overconstrained"
        elif gen_idx == 3:
            # Branching factor 2-4.
            branching = int(self.rng.integers(2, 5))
            bodies, joints, analysis = self.generate_tree_assembly(
                n,
                _joint_pool,
                branching,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "tree"
        elif gen_idx == 4:
            # Loops need at least 3 bodies.
            n = max(n, 3)
            bodies, joints, analysis = self.generate_loop_assembly(
                n,
                _joint_pool,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "loop"
        elif gen_idx == 5:
            # Stars need at least 2 bodies (hub + one satellite).
            n = max(n, 2)
            bodies, joints, analysis = self.generate_star_assembly(
                n,
                _joint_pool,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "star"
        else:
            density = float(self.rng.uniform(0.2, 0.5))
            bodies, joints, analysis = self.generate_mixed_assembly(
                n,
                _joint_pool,
                density,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "mixed"

        # Produce ground truth labels (includes ConstraintAnalysis)
        ground = 0 if grounded else None
        labels = label_assembly(bodies, joints, ground_body=ground)
        # Replace the generator's analysis with the labeling pass's one
        # so the stored fields all come from a single analysis.
        analysis = labels.analysis

        # Build per-joint labels from edge results
        joint_labels: dict[int, dict[str, int]] = {}
        for result in analysis.per_edge_results:
            jid = result["joint_id"]
            if jid not in joint_labels:
                joint_labels[jid] = {
                    "independent_constraints": 0,
                    "redundant_constraints": 0,
                    "total_constraints": 0,
                }
            joint_labels[jid]["total_constraints"] += 1
            if result["independent"]:
                joint_labels[jid]["independent_constraints"] += 1
            else:
                joint_labels[jid]["redundant_constraints"] += 1

        # Serialize the example with JSON-friendly types (lists, ints,
        # enum names) so the batch can be dumped directly.
        examples.append(
            {
                "example_id": i,
                "generator_type": gen_name,
                "grounded": grounded,
                "n_bodies": len(bodies),
                "n_joints": len(joints),
                "body_positions": [b.position.tolist() for b in bodies],
                "body_orientations": [b.orientation.tolist() for b in bodies],
                "joints": [
                    {
                        "joint_id": j.joint_id,
                        "body_a": j.body_a,
                        "body_b": j.body_b,
                        "type": j.joint_type.name,
                        "axis": j.axis.tolist(),
                    }
                    for j in joints
                ],
                "joint_labels": joint_labels,
                "labels": labels.to_dict(),
                "assembly_classification": (analysis.combinatorial_classification),
                "is_rigid": analysis.is_rigid,
                "is_minimally_rigid": analysis.is_minimally_rigid,
                "internal_dof": analysis.jacobian_internal_dof,
                "geometric_degeneracies": (analysis.geometric_degeneracies),
            }
        )

    return examples
|
||||||
517
solver/datagen/jacobian.py
Normal file
517
solver/datagen/jacobian.py
Normal file
@@ -0,0 +1,517 @@
|
|||||||
|
"""Numerical Jacobian rank verification for assembly constraint analysis.
|
||||||
|
|
||||||
|
Builds the constraint Jacobian matrix and analyzes its numerical rank
|
||||||
|
to detect geometric degeneracies that the combinatorial pebble game
|
||||||
|
cannot identify (e.g., parallel revolute axes creating hidden dependencies).
|
||||||
|
|
||||||
|
References:
|
||||||
|
- Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D"
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from solver.datagen.types import Joint, JointType, RigidBody
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = ["JacobianVerifier"]
|
||||||
|
|
||||||
|
|
||||||
|
class JacobianVerifier:
|
||||||
|
"""Builds and analyzes the constraint Jacobian for numerical rank check.
|
||||||
|
|
||||||
|
The pebble game gives a combinatorial *necessary* condition for
|
||||||
|
rigidity. However, geometric special cases (e.g., all revolute axes
|
||||||
|
parallel, creating a hidden dependency) require numerical verification.
|
||||||
|
|
||||||
|
For each joint, we construct the constraint Jacobian rows that map
|
||||||
|
the 6n-dimensional generalized velocity vector to the constraint
|
||||||
|
violation rates. The rank of this Jacobian equals the number of
|
||||||
|
truly independent constraints.
|
||||||
|
|
||||||
|
The generalized velocity vector for n bodies is::
|
||||||
|
|
||||||
|
v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z]
|
||||||
|
|
||||||
|
Each scalar constraint C_i contributes one row to J such that::
|
||||||
|
|
||||||
|
dC_i/dt = J_i @ v = 0
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, bodies: list[RigidBody]) -> None:
    """Initialize the verifier for a fixed set of bodies."""
    # Lookup from body_id to the RigidBody itself.
    self.bodies = {b.body_id: b for b in bodies}
    # Lookup from body_id to its dense column-block index in J.
    self.body_index = {b.body_id: i for i, b in enumerate(bodies)}
    self.n_bodies = len(bodies)
    # Accumulated Jacobian rows, each of width 6 * n_bodies.
    self.jacobian_rows: list[np.ndarray] = []
    # Parallel metadata describing each row (joint id, constraint kind).
    self.row_labels: list[dict[str, Any]] = []
|
def _body_cols(self, body_id: int) -> tuple[int, int]:
    """Return the column range [start, end) for a body in J."""
    # Each body owns a contiguous 6-column block (linear + angular).
    start = 6 * self.body_index[body_id]
    return start, start + 6
|
def add_joint_constraints(self, joint: Joint) -> int:
    """Add Jacobian rows for all scalar constraints of a joint.

    Dispatches to the per-joint-type row builder, which appends its
    rows to ``self.jacobian_rows`` (and matching ``self.row_labels``).

    Returns the number of rows added.
    """
    # Dispatch table: one row builder per supported joint type.
    # NOTE(review): an unsupported type raises KeyError here.
    builder = {
        JointType.FIXED: self._build_fixed,
        JointType.REVOLUTE: self._build_revolute,
        JointType.CYLINDRICAL: self._build_cylindrical,
        JointType.SLIDER: self._build_slider,
        JointType.BALL: self._build_ball,
        JointType.PLANAR: self._build_planar,
        JointType.DISTANCE: self._build_distance,
        JointType.PARALLEL: self._build_parallel,
        JointType.PERPENDICULAR: self._build_perpendicular,
        JointType.UNIVERSAL: self._build_universal,
        JointType.SCREW: self._build_screw,
    }

    # Builders append rows in place; report how many were added.
    rows_before = len(self.jacobian_rows)
    builder[joint.joint_type](joint)
    return len(self.jacobian_rows) - rows_before
|
def _make_row(self) -> np.ndarray:
    """Create a zero row of width 6*n_bodies."""
    return np.zeros(self.n_bodies * 6)
|
||||||
|
|
||||||
|
def _skew(self, v: np.ndarray) -> np.ndarray:
    """Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``."""
    # Column j of skew(v) is v x e_j; assemble the matrix entry by entry.
    m = np.array(
        [
            [0.0 * v[0], -v[2], v[1]],
            [v[2], 0.0 * v[0], -v[0]],
            [-v[1], v[0], 0.0 * v[0]],
        ]
    )
    return m
|
||||||
|
|
||||||
|
# --- Ball-and-socket (spherical) joint: 3 translation constraints ---
|
||||||
|
|
||||||
|
def _build_ball(self, joint: Joint) -> None:
    """Ball joint: coincident point constraint.

    ``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0`` (3 equations)

    Jacobian rows (for each of x, y, z):
        body_a linear: -I,  body_a angular: +skew(R_a @ r_a)
        body_b linear: +I,  body_b angular: -skew(R_b @ r_b)
    """
    # Anchor offsets measured in world frame from each body's origin.
    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # One row per world axis; `unit` is the standard basis vector e_axis.
    for axis_idx in range(3):
        unit = np.zeros(3)
        unit[axis_idx] = 1.0

        row = self._make_row()
        row[a_start : a_start + 3] = -unit
        row[a_start + 3 : a_end] = np.cross(r_a, unit)
        row[b_start : b_start + 3] = unit
        row[b_start + 3 : b_end] = -np.cross(r_b, unit)

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "ball_translation", "axis": axis_idx}
        )
|
||||||
|
|
||||||
|
# --- Fixed joint: 3 translation + 3 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_fixed(self, joint: Joint) -> None:
    """Fixed joint = ball joint + locked rotation.

    Translation part: same as ball joint (3 rows).
    Rotation part: relative angular velocity must be zero (3 rows).
    """
    self._build_ball(joint)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Lock each component of the relative angular velocity.
    for axis_idx in range(3):
        row = self._make_row()
        row[a_start + 3 + axis_idx] = -1.0
        row[b_start + 3 + axis_idx] = 1.0

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "fixed_rotation", "axis": axis_idx}
        )
|
||||||
|
|
||||||
|
# --- Revolute (hinge) joint: 3 translation + 2 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_revolute(self, joint: Joint) -> None:
    """Revolute joint: rotation only about one axis.

    Translation: same as ball (3 rows).
    Rotation: relative angular velocity must be parallel to hinge axis.
    """
    self._build_ball(joint)

    hinge = joint.axis / np.linalg.norm(joint.axis)
    perps = self._perpendicular_pair(hinge)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Kill angular velocity along both directions perpendicular to the
    # hinge; spin about the hinge itself stays free.
    for perp_idx, perp in enumerate(perps):
        row = self._make_row()
        row[a_start + 3 : a_start + 6] = -perp
        row[b_start + 3 : b_start + 6] = perp

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "revolute_rotation", "perp_axis": perp_idx}
        )
|
||||||
|
|
||||||
|
# --- Cylindrical joint: 2 translation + 2 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_cylindrical(self, joint: Joint) -> None:
    """Cylindrical joint: allows rotation + translation along one axis.

    Translation: constrain motion perpendicular to axis (2 rows).
    Rotation: constrain rotation perpendicular to axis (2 rows).
    """
    slide_axis = joint.axis / np.linalg.norm(joint.axis)
    perps = self._perpendicular_pair(slide_axis)

    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # No relative translation along either perpendicular direction.
    for perp_idx, perp in enumerate(perps):
        row = self._make_row()
        row[a_start : a_start + 3] = -perp
        row[a_start + 3 : a_end] = np.cross(r_a, perp)
        row[b_start : b_start + 3] = perp
        row[b_start + 3 : b_end] = -np.cross(r_b, perp)

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "cylindrical_translation", "perp_axis": perp_idx}
        )

    # No relative rotation about either perpendicular direction.
    for perp_idx, perp in enumerate(perps):
        row = self._make_row()
        row[a_start + 3 : a_start + 6] = -perp
        row[b_start + 3 : b_start + 6] = perp

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "cylindrical_rotation", "perp_axis": perp_idx}
        )
|
||||||
|
|
||||||
|
# --- Slider (prismatic) joint: 2 translation + 3 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_slider(self, joint: Joint) -> None:
    """Slider/prismatic joint: translation along one axis only.

    Translation: perpendicular translation constrained (2 rows).
    Rotation: all relative rotation constrained (3 rows).
    """
    slide_axis = joint.axis / np.linalg.norm(joint.axis)
    perps = self._perpendicular_pair(slide_axis)

    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # No relative translation perpendicular to the slide axis.
    for perp_idx, perp in enumerate(perps):
        row = self._make_row()
        row[a_start : a_start + 3] = -perp
        row[a_start + 3 : a_end] = np.cross(r_a, perp)
        row[b_start : b_start + 3] = perp
        row[b_start + 3 : b_end] = -np.cross(r_b, perp)

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "slider_translation", "perp_axis": perp_idx}
        )

    # No relative rotation at all.
    for axis_idx in range(3):
        row = self._make_row()
        row[a_start + 3 + axis_idx] = -1.0
        row[b_start + 3 + axis_idx] = 1.0

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "slider_rotation", "axis": axis_idx}
        )
|
||||||
|
|
||||||
|
# --- Planar joint: 1 translation + 2 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_planar(self, joint: Joint) -> None:
    """Planar joint: constrains to a plane.

    Translation: motion along plane normal constrained (1 row).
    Rotation: rotation about axes in the plane constrained (2 rows).
    """
    normal = joint.axis / np.linalg.norm(joint.axis)
    in_plane = self._perpendicular_pair(normal)

    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # Single translational row: no relative motion along the plane normal.
    row = self._make_row()
    row[a_start : a_start + 3] = -normal
    row[a_start + 3 : a_end] = np.cross(r_a, normal)
    row[b_start : b_start + 3] = normal
    row[b_start + 3 : b_end] = -np.cross(r_b, normal)
    self.jacobian_rows.append(row)
    self.row_labels.append({"joint_id": joint.joint_id, "type": "planar_translation"})

    # Rotation about any in-plane direction would tilt out of the plane.
    for perp_idx, direction in enumerate(in_plane):
        row = self._make_row()
        row[a_start + 3 : a_start + 6] = -direction
        row[b_start + 3 : b_start + 6] = direction
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "planar_rotation", "perp_axis": perp_idx}
        )
|
||||||
|
|
||||||
|
# --- Distance constraint: 1 scalar ---
|
||||||
|
|
||||||
|
def _build_distance(self, joint: Joint) -> None:
    """Distance constraint: ``||p_b - p_a|| = d``.

    Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0``
    where ``direction = normalized(p_b - p_a)``.
    """
    separation = joint.anchor_b - joint.anchor_a
    length = np.linalg.norm(separation)
    # Coincident anchors: fall back to an arbitrary fixed direction.
    direction = np.array([1.0, 0.0, 0.0]) if length < 1e-12 else separation / length

    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    row = self._make_row()
    row[a_start : a_start + 3] = -direction
    row[a_start + 3 : a_end] = np.cross(r_a, direction)
    row[b_start : b_start + 3] = direction
    row[b_start + 3 : b_end] = -np.cross(r_b, direction)

    self.jacobian_rows.append(row)
    self.row_labels.append({"joint_id": joint.joint_id, "type": "distance"})
|
||||||
|
|
||||||
|
# --- Parallel constraint: 3 rotation constraints ---
|
||||||
|
|
||||||
|
def _build_parallel(self, joint: Joint) -> None:
    """Parallel: all relative rotation constrained (same as fixed rotation).

    In practice only 2 of 3 are independent for a single axis, but
    we emit 3 and let the rank check sort it out.
    """
    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    for axis_idx in range(3):
        row = self._make_row()
        row[a_start + 3 + axis_idx] = -1.0
        row[b_start + 3 + axis_idx] = 1.0
        self.jacobian_rows.append(row)
        self.row_labels.append(
            {"joint_id": joint.joint_id, "type": "parallel_rotation", "axis": axis_idx}
        )
|
||||||
|
|
||||||
|
# --- Perpendicular constraint: 1 angular ---
|
||||||
|
|
||||||
|
def _build_perpendicular(self, joint: Joint) -> None:
    """Perpendicular: single dot-product angular constraint."""
    direction = joint.axis / np.linalg.norm(joint.axis)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Relative angular velocity may have no component along this direction.
    row = self._make_row()
    row[a_start + 3 : a_start + 6] = -direction
    row[b_start + 3 : b_start + 6] = direction
    self.jacobian_rows.append(row)
    self.row_labels.append({"joint_id": joint.joint_id, "type": "perpendicular"})
|
||||||
|
|
||||||
|
# --- Universal (Cardan) joint: 3 translation + 1 rotation ---
|
||||||
|
|
||||||
|
def _build_universal(self, joint: Joint) -> None:
    """Universal joint: ball + one rotation constraint.

    Allows rotation about two axes, constrains rotation about the third.
    """
    self._build_ball(joint)

    locked = joint.axis / np.linalg.norm(joint.axis)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Constrain relative spin about the locked axis only.
    row = self._make_row()
    row[a_start + 3 : a_start + 6] = -locked
    row[b_start + 3 : b_start + 6] = locked
    self.jacobian_rows.append(row)
    self.row_labels.append({"joint_id": joint.joint_id, "type": "universal_rotation"})
|
||||||
|
|
||||||
|
# --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled ---
|
||||||
|
|
||||||
|
def _build_screw(self, joint: Joint) -> None:
    """Screw joint: coupled rotation-translation along axis.

    Like cylindrical but with a coupling constraint:
    ``v_axial - pitch * w_axial = 0``
    """
    self._build_cylindrical(joint)

    screw_axis = joint.axis / np.linalg.norm(joint.axis)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Couple axial sliding to axial spin via the thread pitch.
    row = self._make_row()
    row[a_start : a_start + 3] = -screw_axis
    row[b_start : b_start + 3] = screw_axis
    row[a_start + 3 : a_start + 6] = joint.pitch * screw_axis
    row[b_start + 3 : b_start + 6] = -joint.pitch * screw_axis
    self.jacobian_rows.append(row)
    self.row_labels.append({"joint_id": joint.joint_id, "type": "screw_coupling"})
|
||||||
|
|
||||||
|
# --- Utilities ---
|
||||||
|
|
||||||
|
def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Generate two unit vectors perpendicular to *axis* and each other."""
    # Cross with whichever world axis is least aligned with *axis* so the
    # cross product cannot be near zero.
    if abs(axis[0]) < 0.9:
        seed = np.array([1.0, 0, 0])
    else:
        seed = np.array([0, 1.0, 0])
    t1 = np.cross(axis, seed)
    t1 /= np.linalg.norm(t1)
    t2 = np.cross(axis, t1)
    t2 /= np.linalg.norm(t2)
    return t1, t2
|
||||||
|
|
||||||
|
def get_jacobian(self) -> np.ndarray:
    """Return the full constraint Jacobian matrix."""
    if self.jacobian_rows:
        return np.array(self.jacobian_rows)
    # No constraints yet: an empty matrix with the correct column count.
    return np.zeros((0, 6 * self.n_bodies))
|
||||||
|
|
||||||
|
def numerical_rank(self, tol: float = 1e-8) -> int:
    """Compute the numerical rank of the constraint Jacobian via SVD.

    This is the number of truly independent scalar constraints,
    accounting for geometric degeneracies that the combinatorial
    pebble game cannot detect.
    """
    jac = self.get_jacobian()
    if jac.size == 0:
        return 0
    singular_values = np.linalg.svd(jac, compute_uv=False)
    return int((singular_values > tol).sum())
|
||||||
|
|
||||||
|
def find_dependencies(self, tol: float = 1e-8) -> list[int]:
    """Identify which constraint rows are numerically dependent.

    Returns indices of rows that can be removed without changing
    the Jacobian's rank.
    """
    jac = self.get_jacobian()
    if jac.size == 0:
        return []

    dependent: list[int] = []
    kept = np.zeros((0, jac.shape[1]))
    kept_rank = 0

    # Greedy forward pass: a row is dependent iff appending it to the rows
    # kept so far does not increase the numerical rank.
    for row_idx in range(jac.shape[0]):
        if kept.size:
            trial = np.vstack([kept, jac[row_idx : row_idx + 1, :]])
        else:
            trial = jac[row_idx : row_idx + 1, :]
        trial_rank = int(np.sum(np.linalg.svd(trial, compute_uv=False) > tol))

        if trial_rank > kept_rank:
            kept, kept_rank = trial, trial_rank
        else:
            dependent.append(row_idx)

    return dependent
|
||||||
394
solver/datagen/labeling.py
Normal file
394
solver/datagen/labeling.py
Normal file
@@ -0,0 +1,394 @@
|
|||||||
|
"""Ground truth labeling pipeline for synthetic assemblies.
|
||||||
|
|
||||||
|
Produces rich per-constraint, per-joint, per-body, and assembly-level
|
||||||
|
labels by running both the pebble game and Jacobian verification and
|
||||||
|
correlating their results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from solver.datagen.jacobian import JacobianVerifier
|
||||||
|
from solver.datagen.pebble_game import PebbleGame3D
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = ["AssemblyLabels", "label_assembly"]
|
||||||
|
|
||||||
|
_GROUND_ID = -1
|
||||||
|
_SVD_TOL = 1e-8
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Label dataclasses
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConstraintLabel:
    """Label for one scalar constraint, combining both analysis methods.

    ``pebble_independent`` comes from the combinatorial pebble game;
    ``jacobian_independent`` from the numerical SVD rank check.
    """

    joint_id: int
    constraint_idx: int
    pebble_independent: bool
    jacobian_independent: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class JointLabel:
    """Aggregated constraint counts for a single joint.

    ``independent_count + redundant_count == total`` once aggregation
    has finished; counters are mutated incrementally by the labeler.
    """

    joint_id: int
    independent_count: int
    redundant_count: int
    total: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BodyDofLabel:
    """Per-body DOF signature from nullspace projection.

    Each count is in [0, 3]: remaining free translational and
    rotational degrees of freedom for one body.
    """

    body_id: int
    translational_dof: int
    rotational_dof: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AssemblyLabel:
    """Assembly-wide summary label.

    ``classification`` is one of the labeler's category strings
    (e.g. "well-constrained", "overconstrained").
    """

    classification: str
    total_dof: int
    redundant_count: int
    is_rigid: bool
    is_minimally_rigid: bool
    has_degeneracy: bool
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class AssemblyLabels:
    """Complete ground truth labels for an assembly."""

    per_constraint: list[ConstraintLabel]
    per_joint: list[JointLabel]
    per_body: list[BodyDofLabel]
    assembly: AssemblyLabel
    analysis: ConstraintAnalysis

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict.

        ``analysis`` is intentionally not included in the output.
        """

        def constraint_entry(c):
            return {
                "joint_id": c.joint_id,
                "constraint_idx": c.constraint_idx,
                "pebble_independent": c.pebble_independent,
                "jacobian_independent": c.jacobian_independent,
            }

        def joint_entry(j):
            return {
                "joint_id": j.joint_id,
                "independent_count": j.independent_count,
                "redundant_count": j.redundant_count,
                "total": j.total,
            }

        def body_entry(b):
            return {
                "body_id": b.body_id,
                "translational_dof": b.translational_dof,
                "rotational_dof": b.rotational_dof,
            }

        return {
            "per_constraint": [constraint_entry(c) for c in self.per_constraint],
            "per_joint": [joint_entry(j) for j in self.per_joint],
            "per_body": [body_entry(b) for b in self.per_body],
            "assembly": {
                "classification": self.assembly.classification,
                "total_dof": self.assembly.total_dof,
                "redundant_count": self.assembly.redundant_count,
                "is_rigid": self.assembly.is_rigid,
                "is_minimally_rigid": self.assembly.is_minimally_rigid,
                "has_degeneracy": self.assembly.has_degeneracy,
            },
        }
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-body DOF from nullspace projection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _compute_per_body_dof(
    j_reduced: np.ndarray,
    body_ids: list[int],
    ground_body: int | None,
    body_index: dict[int, int],
) -> list[BodyDofLabel]:
    """Compute translational and rotational DOF per body.

    Uses SVD nullspace projection: for each body, extract its
    translational (3 cols) and rotational (3 cols) components
    from the nullspace basis and compute ranks.

    NOTE(review): ``body_index`` is accepted but not consulted; the column
    layout is reconstructed from ``body_ids`` order instead — confirm the
    two orderings always agree at the call site.
    """
    # Column-block index of each body in the reduced Jacobian
    # (the ground body's columns were removed upstream).
    non_ground = [bid for bid in body_ids if bid != ground_body]
    col_map = {bid: position for position, bid in enumerate(non_ground)}

    if j_reduced.size == 0:
        # No constraints: the ground body is pinned, everything else is free.
        return [
            BodyDofLabel(
                body_id=bid,
                translational_dof=0 if bid == ground_body else 3,
                rotational_dof=0 if bid == ground_body else 3,
            )
            for bid in body_ids
        ]

    # Full SVD so Vh spans the entire column space, nullspace included.
    _u, s, vh = np.linalg.svd(j_reduced, full_matrices=True)
    rank = int(np.sum(s > _SVD_TOL))
    n_cols = j_reduced.shape[1]

    if rank >= n_cols:
        # Fully constrained: empty nullspace, no body can move.
        return [
            BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)
            for bid in body_ids
        ]

    # Nullspace basis: the rows of Vh past the numerical rank.
    nullspace = vh[rank:]

    def block_rank(block: np.ndarray) -> int:
        # Rank of a body's nullspace sub-block = free DOF in that category.
        if block.size == 0:
            return 0
        return int(np.sum(np.linalg.svd(block, compute_uv=False) > _SVD_TOL))

    labels: list[BodyDofLabel] = []
    for bid in body_ids:
        if bid == ground_body:
            labels.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
            continue

        base = 6 * col_map[bid]
        labels.append(
            BodyDofLabel(
                body_id=bid,
                translational_dof=block_rank(nullspace[:, base : base + 3]),
                rotational_dof=block_rank(nullspace[:, base + 3 : base + 6]),
            )
        )

    return labels
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Main labeling function
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def label_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> AssemblyLabels:
    """Produce complete ground truth labels for an assembly.

    Runs both the pebble game and Jacobian verification internally,
    then correlates their results into per-constraint, per-joint,
    per-body, and assembly-level labels.

    Args:
        bodies: Rigid bodies in the assembly.
        joints: Joints connecting the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        AssemblyLabels with full label set and embedded ConstraintAnalysis.
    """
    # ---- Pebble Game (combinatorial analysis) ----
    pg = PebbleGame3D()
    all_edge_results: list[dict[str, Any]] = []

    # Grounding is modeled as an extra pseudo-body fixed to the world.
    if ground_body is not None:
        pg.add_body(_GROUND_ID)

    for body in bodies:
        pg.add_body(body.body_id)

    if ground_body is not None:
        # A FIXED pseudo-joint pins the designated body to the ground vertex.
        ground_joint = Joint(
            joint_id=-1,
            body_a=ground_body,
            body_b=_GROUND_ID,
            joint_type=JointType.FIXED,
            anchor_a=bodies[0].position if bodies else np.zeros(3),
            anchor_b=bodies[0].position if bodies else np.zeros(3),
        )
        pg.add_joint(ground_joint)

    for joint in joints:
        results = pg.add_joint(joint)
        all_edge_results.extend(results)

    grounded = ground_body is not None
    combinatorial_independent = len(pg.state.independent_edges)
    raw_dof = pg.get_dof()
    # The ground pseudo-body contributed 6 DOF that must be discounted.
    ground_offset = 6 if grounded else 0
    effective_dof = raw_dof - ground_offset
    # Internal DOF excludes the free-floating rigid-body motion (6 DOF)
    # when nothing is grounded; when grounded, that motion is already
    # removed by the ground joint.
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))

    redundant_count = pg.get_redundant_count()
    # Classify by the combination of redundancy and leftover mobility.
    if redundant_count > 0 and effective_internal_dof > 0:
        classification = "mixed"
    elif redundant_count > 0:
        classification = "overconstrained"
    elif effective_internal_dof > 0:
        classification = "underconstrained"
    else:
        classification = "well-constrained"

    # ---- Jacobian Verification (numerical analysis) ----
    verifier = JacobianVerifier(bodies)

    for joint in joints:
        verifier.add_joint_constraints(joint)

    j_full = verifier.get_jacobian()
    j_reduced = j_full.copy()
    # Grounding numerically: delete the ground body's 6 velocity columns.
    if ground_body is not None and j_reduced.size > 0:
        idx = verifier.body_index[ground_body]
        cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
        j_reduced = np.delete(j_reduced, cols_to_remove, axis=1)

    if j_reduced.size > 0:
        sv = np.linalg.svd(j_reduced, compute_uv=False)
        jacobian_rank = int(np.sum(sv > _SVD_TOL))
    else:
        jacobian_rank = 0

    n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies)
    jacobian_nullity = n_cols - jacobian_rank
    # NOTE(review): dependencies are computed on the FULL Jacobian while the
    # rank above uses the ground-reduced one — confirm this is intentional.
    dependent_rows = verifier.find_dependencies()
    dependent_set = set(dependent_rows)

    # When grounded there is no trivial rigid-body motion left in j_reduced.
    trivial_dof = 0 if grounded else 6
    jacobian_internal_dof = jacobian_nullity - trivial_dof
    # Constraints the pebble game accepted but the SVD shows as dependent.
    geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
    is_rigid = jacobian_nullity <= trivial_dof
    is_minimally_rigid = is_rigid and len(dependent_rows) == 0

    # ---- Per-constraint labels ----
    # Map Jacobian rows to (joint_id, constraint_index).
    # Rows are added contiguously per joint in the same order as joints.
    row_to_joint: list[tuple[int, int]] = []
    for joint in joints:
        dof = joint.joint_type.dof
        for ci in range(dof):
            row_to_joint.append((joint.joint_id, ci))

    per_constraint: list[ConstraintLabel] = []
    for edge_idx, edge_result in enumerate(all_edge_results):
        jid = edge_result["joint_id"]
        ci = edge_result["constraint_index"]
        pebble_indep = edge_result["independent"]

        # Find matching Jacobian row; assumes pebble edges and Jacobian rows
        # were emitted in the same order (one per scalar constraint).
        jacobian_indep = True
        if edge_idx < len(row_to_joint):
            row_idx = edge_idx
            jacobian_indep = row_idx not in dependent_set

        per_constraint.append(
            ConstraintLabel(
                joint_id=jid,
                constraint_idx=ci,
                pebble_independent=pebble_indep,
                jacobian_independent=jacobian_indep,
            )
        )

    # ---- Per-joint labels (aggregate the pebble verdicts per joint) ----
    joint_agg: dict[int, JointLabel] = {}
    for cl in per_constraint:
        if cl.joint_id not in joint_agg:
            joint_agg[cl.joint_id] = JointLabel(
                joint_id=cl.joint_id,
                independent_count=0,
                redundant_count=0,
                total=0,
            )
        jl = joint_agg[cl.joint_id]
        jl.total += 1
        if cl.pebble_independent:
            jl.independent_count += 1
        else:
            jl.redundant_count += 1

    # Preserve the input joint order in the output list.
    per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg]

    # ---- Per-body DOF labels (nullspace projection) ----
    body_ids = [b.body_id for b in bodies]
    per_body = _compute_per_body_dof(
        j_reduced,
        body_ids,
        ground_body,
        verifier.body_index,
    )

    # ---- Assembly label ----
    assembly_label = AssemblyLabel(
        classification=classification,
        total_dof=max(0, jacobian_internal_dof),
        redundant_count=redundant_count,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
        has_degeneracy=geometric_degeneracies > 0,
    )

    # ---- ConstraintAnalysis (for backward compat) ----
    analysis = ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=redundant_count,
        combinatorial_classification=classification,
        per_edge_results=all_edge_results,
        jacobian_rank=jacobian_rank,
        jacobian_nullity=jacobian_nullity,
        jacobian_internal_dof=max(0, jacobian_internal_dof),
        numerically_dependent=dependent_rows,
        geometric_degeneracies=geometric_degeneracies,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
    )

    return AssemblyLabels(
        per_constraint=per_constraint,
        per_joint=per_joint,
        per_body=per_body,
        assembly=assembly_label,
        analysis=analysis,
    )
|
||||||
258
solver/datagen/pebble_game.py
Normal file
258
solver/datagen/pebble_game.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis.
|
||||||
|
|
||||||
|
Implements the pebble game algorithm adapted for CAD assembly constraint
|
||||||
|
graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints
|
||||||
|
between bodies remove DOF according to their type.
|
||||||
|
|
||||||
|
The pebble game provides a fast combinatorial *necessary* condition for
|
||||||
|
rigidity via Tay's theorem. It does not detect geometric degeneracies —
|
||||||
|
use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient*
|
||||||
|
condition.
|
||||||
|
|
||||||
|
References:
|
||||||
|
- Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008
|
||||||
|
- Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity
|
||||||
|
Percolation: The Pebble Game", J. Comput. Phys., 1997
|
||||||
|
- Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from solver.datagen.types import Joint, PebbleState
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = ["PebbleGame3D"]
|
||||||
|
|
||||||
|
|
||||||
|
class PebbleGame3D:
    """Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks.

    For body-bar-hinge structures in 3D, Tay's theorem states that a
    multigraph G on n vertices is generically minimally rigid iff:

        |E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2)

    The (6,6)-pebble game tests this sparsity condition incrementally.
    Each vertex starts with 6 pebbles (representing 6 DOF). To insert
    an edge, we need to collect 6+1=7 pebbles on its two endpoints.
    If we can, the edge is independent (removes a DOF). If not, it's
    redundant (overconstrained).

    In the CAD assembly context:
    - Vertices = rigid bodies
    - Edges = scalar constraints from joints
    - A revolute joint (5 DOF removed) maps to 5 multigraph edges
    - A fixed joint (6 DOF removed) maps to 6 multigraph edges
    """

    K = 6  # Pebbles per vertex (DOF per rigid body in 3D)
    L = 6  # Sparsity parameter: need L+1=7 pebbles to accept edge

    def __init__(self) -> None:
        # All mutable game state (pebble counts, directed edges, the
        # independent/redundant partition) lives in a single PebbleState.
        self.state = PebbleState()
        self._edge_counter = 0  # Monotonic id assigned to each scalar-constraint edge
        self._bodies: set[int] = set()  # Vertices registered so far

    def add_body(self, body_id: int) -> None:
        """Register a rigid body (vertex) with K=6 free pebbles."""
        if body_id in self._bodies:
            # Idempotent: re-registering an existing body is a no-op so
            # add_joint can call this unconditionally.
            return
        self._bodies.add(body_id)
        self.state.free_pebbles[body_id] = self.K
        self.state.incoming[body_id] = set()
        self.state.outgoing[body_id] = set()

    def add_joint(self, joint: Joint) -> list[dict[str, Any]]:
        """Expand a joint into multigraph edges and test each for independence.

        A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph.
        Each edge is tested individually via the pebble game.

        Returns a list of dicts, one per scalar constraint, with:
        - edge_id: int
        - independent: bool
        - dof_remaining: int (total free pebbles after this edge)
        """
        # Bodies are registered lazily, so callers need not call add_body first.
        self.add_body(joint.body_a)
        self.add_body(joint.body_b)

        num_constraints = joint.joint_type.dof
        results: list[dict[str, Any]] = []

        for i in range(num_constraints):
            edge_id = self._edge_counter
            self._edge_counter += 1

            independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b)
            # Snapshot of the assembly-wide free-pebble total after this edge.
            total_free = sum(self.state.free_pebbles.values())

            results.append(
                {
                    "edge_id": edge_id,
                    "joint_id": joint.joint_id,
                    "constraint_index": i,
                    "independent": independent,
                    "dof_remaining": total_free,
                }
            )

        return results

    def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool:
        """Try to insert a directed edge between u and v.

        The edge is accepted (independent) iff we can collect L+1 = 7
        pebbles on the two endpoints {u, v} combined.

        If accepted, one pebble is consumed and the edge is directed
        away from the vertex that gives up the pebble.
        """
        # Count current free pebbles on u and v
        available = self.state.free_pebbles[u] + self.state.free_pebbles[v]

        # Try to gather enough pebbles via DFS reachability search
        if available < self.L + 1:
            needed = (self.L + 1) - available
            # Try to free pebbles by searching from u first, then v.
            # ``needed`` is shared across both endpoint searches: pebbles
            # already pulled onto u reduce what the v-search must find.
            for target in (u, v):
                while needed > 0:
                    found = self._search_and_collect(target, frozenset({u, v}))
                    if not found:
                        break
                    needed -= 1

            # Recheck after collection attempts
            available = self.state.free_pebbles[u] + self.state.free_pebbles[v]

        if available >= self.L + 1:
            # Accept: consume a pebble from whichever endpoint has one
            # (at least one must, since together they hold >= 7 > 0).
            source = u if self.state.free_pebbles[u] > 0 else v

            self.state.free_pebbles[source] -= 1
            self.state.directed_edges[edge_id] = (source, v if source == u else u)
            self.state.outgoing[source].add((edge_id, v if source == u else u))
            target = v if source == u else u
            self.state.incoming[target].add((edge_id, source))
            self.state.independent_edges.add(edge_id)
            return True
        else:
            # Reject: edge is redundant (overconstrained)
            self.state.redundant_edges.add(edge_id)
            return False

    def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
        """DFS to find a free pebble reachable from *target* and move it.

        Follows directed edges *backwards* (from destination to source)
        to find a vertex with a free pebble that isn't in *forbidden*.
        When found, reverses the path to move the pebble to *target*.

        Returns True if a pebble was successfully moved to target.
        """
        # BFS/DFS through the directed graph following outgoing edges
        # from target. An outgoing edge (target -> w) means target spent
        # a pebble on that edge. If we can find a vertex with a free
        # pebble, we reverse edges along the path to move it.

        visited: set[int] = set()
        # Stack: (current_vertex, path_of_edge_ids_to_reverse)
        # Each stack entry carries its own path copy, so the path is
        # always a valid edge sequence from target to current.
        stack: list[tuple[int, list[int]]] = [(target, [])]

        while stack:
            current, path = stack.pop()
            if current in visited:
                continue
            visited.add(current)

            # Check if current vertex (not in forbidden, not target)
            # has a free pebble
            if (
                current != target
                and current not in forbidden
                and self.state.free_pebbles[current] > 0
            ):
                # Found a pebble — reverse the path
                self._reverse_path(path, current)
                return True

            # Follow outgoing edges from current vertex
            for eid, neighbor in self.state.outgoing.get(current, set()):
                if neighbor not in visited:
                    stack.append((neighbor, [*path, eid]))

        return False

    def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
        """Reverse directed edges along a path, moving a pebble to the start.

        The pebble at *pebble_source* is consumed by the last edge in
        the path, and a pebble is freed at the path's start vertex.
        """
        if not edge_ids:
            # Empty path: pebble is already at the search start; nothing to move.
            return

        # Reverse each edge in the path.  Intermediate vertices both gain
        # and spend one pebble, so only the two endpoints change net count.
        for eid in edge_ids:
            old_source, old_target = self.state.directed_edges[eid]

            # Remove from adjacency
            self.state.outgoing[old_source].discard((eid, old_target))
            self.state.incoming[old_target].discard((eid, old_source))

            # Reverse direction
            self.state.directed_edges[eid] = (old_target, old_source)
            self.state.outgoing[old_target].add((eid, old_source))
            self.state.incoming[old_source].add((eid, old_target))

        # Move pebble counts: source loses one, first vertex in path gains one
        self.state.free_pebbles[pebble_source] -= 1

        # After all reversals, the vertex at the beginning of the
        # search path gains a pebble (edge_ids[0] now points back at it).
        _first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
        self.state.free_pebbles[first_tgt] += 1

    def get_dof(self) -> int:
        """Total remaining DOF = sum of free pebbles.

        For a fully rigid assembly, this should be 6 (the trivial rigid
        body motions of the whole assembly). Internal DOF = total - 6.
        """
        return sum(self.state.free_pebbles.values())

    def get_internal_dof(self) -> int:
        """Internal (non-trivial) degrees of freedom."""
        return max(0, self.get_dof() - 6)

    def is_rigid(self) -> bool:
        """Combinatorial rigidity check: rigid iff at most 6 pebbles remain."""
        return self.get_dof() <= self.L

    def get_redundant_count(self) -> int:
        """Number of redundant (overconstrained) scalar constraints."""
        return len(self.state.redundant_edges)

    def classify_assembly(self, *, grounded: bool = False) -> str:
        """Classify the assembly state.

        Args:
            grounded: If True, the baseline trivial DOF is 0 (not 6),
                because the ground body's 6 DOF were removed.
        """
        total_dof = self.get_dof()
        redundant = self.get_redundant_count()
        baseline = 0 if grounded else self.L

        if redundant > 0 and total_dof > baseline:
            return "mixed"  # Both under and over-constrained regions
        elif redundant > 0:
            return "overconstrained"
        elif total_dof > baseline:
            return "underconstrained"
        elif total_dof == baseline:
            return "well-constrained"
        else:
            # Defensive: fewer pebbles than the baseline should not occur,
            # but treat it as overconstrained rather than crash.
            return "overconstrained"
|
||||||
144
solver/datagen/types.py
Normal file
144
solver/datagen/types.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Shared data types for assembly constraint analysis.
|
||||||
|
|
||||||
|
Types ported from the pebble-game synthetic data generator for reuse
|
||||||
|
across the solver package (data generation, training, inference).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import enum
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ConstraintAnalysis",
|
||||||
|
"Joint",
|
||||||
|
"JointType",
|
||||||
|
"PebbleState",
|
||||||
|
"RigidBody",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Joint definitions: each joint type removes a known number of DOF
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class JointType(enum.Enum):
    """Standard CAD joint types, each tagged with the DOF it removes.

    A joint between two free 6-DOF rigid bodies eliminates a fixed
    number of *relative* degrees of freedom.  In the body-bar-hinge
    multigraph representation, a joint therefore expands into exactly
    that many parallel edges.

    Each member's value is an ``(ordinal, dof_removed)`` pair.  The
    ordinal keeps members that share a DOF count distinct (equal enum
    values would otherwise alias into one member); use the :attr:`dof`
    property to read the scalar-constraint count.
    """

    FIXED = (0, 6)  # No relative motion at all
    REVOLUTE = (1, 5)  # One rotational axis left free (hinge)
    CYLINDRICAL = (2, 4)  # Rotation plus translation along a shared axis
    SLIDER = (3, 5)  # One translational axis left free (prismatic)
    BALL = (4, 3)  # Free rotation about a common point (spherical)
    PLANAR = (5, 3)  # In-plane translation plus rotation about the normal
    SCREW = (6, 5)  # Helical: rotation coupled to translation
    UNIVERSAL = (7, 4)  # Cardan/U-joint: two rotational DOF remain
    PARALLEL = (8, 3)  # Orientation locked parallel (3 rotational constraints)
    PERPENDICULAR = (9, 1)  # One angular constraint
    DISTANCE = (10, 1)  # One scalar separation constraint

    @property
    def dof(self) -> int:
        """Number of scalar constraints (DOF removed) by this joint type."""
        _ordinal, dof_removed = self.value
        return dof_removed
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Data structures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class RigidBody:
    """One rigid body of the assembly: world pose plus joint-anchor geometry."""

    # Unique integer identifier of this body within the assembly.
    body_id: int
    # World-frame translation; defaults to the origin.
    position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # World-frame rotation matrix; defaults to the identity (no rotation).
    orientation: np.ndarray = field(default_factory=lambda: np.identity(3))

    # Joint anchor points in the body's local frame, keyed by name.
    # Filled in when joints reference specific geometry.
    local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class Joint:
    """A kinematic joint linking two rigid bodies."""

    joint_id: int
    body_a: int  # Index of the first connected body
    body_b: int  # Index of the second connected body
    joint_type: JointType

    # Joint geometry expressed in the world frame: one anchor point per
    # body plus the joint axis (defaults to +z).
    anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
    anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
    axis: np.ndarray = field(default_factory=lambda: np.array([0.0, 0.0, 1.0]))

    # Thread pitch; meaningful only for screw joints.
    pitch: float = 0.0
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class PebbleState:
    """Mutable bookkeeping for the pebble game on the constraint multigraph."""

    # body_id -> number of free pebbles still held (each body starts at 6).
    free_pebbles: dict[int, int] = field(default_factory=dict)

    # edge_id -> (source_body, target_body).  Each edge points away from
    # the body that spent a pebble to cover it.
    directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)

    # Partition of the inserted edges into independent vs. redundant sets.
    independent_edges: set[int] = field(default_factory=set)
    redundant_edges: set[int] = field(default_factory=set)

    # body_id -> {(edge_id, neighbor_body_id)} for edges arriving at the body.
    incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)

    # body_id -> {(edge_id, neighbor_body_id)} for edges leaving the body.
    outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConstraintAnalysis:
    """Combined combinatorial + numerical report for an assembly's constraints."""

    # --- Pebble-game (combinatorial) results ---
    combinatorial_dof: int
    combinatorial_internal_dof: int
    combinatorial_redundant: int
    combinatorial_classification: str
    per_edge_results: list[dict[str, Any]]

    # --- Jacobian (numerical) results ---
    jacobian_rank: int
    jacobian_nullity: int  # = 6n - rank = total DOF
    jacobian_internal_dof: int  # = nullity - 6
    numerically_dependent: list[int]

    # --- Combined verdict ---
    geometric_degeneracies: int  # = combinatorial_independent - jacobian_rank
    is_rigid: bool
    is_minimally_rigid: bool
|
||||||
0
solver/datasets/__init__.py
Normal file
0
solver/datasets/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
0
solver/models/__init__.py
Normal file
0
solver/models/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Tests for solver.datagen.analysis -- combined analysis function."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from solver.datagen.analysis import analyze_assembly
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _two_bodies() -> list[RigidBody]:
    """Build two bodies 2 units apart along the x-axis."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
|
||||||
|
|
||||||
|
|
||||||
|
def _triangle_bodies() -> list[RigidBody]:
    """Build three bodies at the corners of a triangle in the z=0 plane."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [1.0, 1.7, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwoBodiesRevolute:
    """Demo scenario 1: two bodies connected by a revolute joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # One revolute joint (removes 5 DOF) between a grounded body and
        # one free body: a single internal DOF should remain.
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        # The hinge leaves exactly one (rotational) DOF.
        assert result.jacobian_internal_dof == 1

    def test_not_rigid(self, result: ConstraintAnalysis) -> None:
        assert not result.is_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification == "underconstrained"

    def test_no_redundant(self, result: ConstraintAnalysis) -> None:
        # A single joint cannot produce redundant constraints here.
        assert result.combinatorial_redundant == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwoBodiesFixed:
    """Demo scenario 2: two bodies connected by a fixed joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # A fixed joint removes all 6 relative DOF, so with body 0
        # grounded the assembly should be exactly rigid.
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.FIXED,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        assert result.jacobian_internal_dof == 0

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
        # Rigid with no redundancy: every constraint is necessary.
        assert result.is_minimally_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_classification == "well-constrained"
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 3: Triangle with revolute joints (overconstrained)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTriangleRevolute:
    """Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # Closing the kinematic loop with a third revolute joint adds
        # more scalar constraints than remaining DOF -> redundancy.
        bodies = _triangle_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.5, 0.85, 0.0]),
                anchor_b=np.array([1.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                2,
                body_a=2,
                body_b=0,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([0.5, 0.85, 0.0]),
                anchor_b=np.array([0.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_has_redundant(self, result: ConstraintAnalysis) -> None:
        assert result.combinatorial_redundant > 0

    def test_classification(self, result: ConstraintAnalysis) -> None:
        # Either pure overconstraint or a mix, depending on where the
        # redundancy lands combinatorially.
        assert result.combinatorial_classification in ("overconstrained", "mixed")

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        assert result.is_rigid

    def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
        # The Jacobian must flag at least one dependent constraint row.
        assert len(result.numerically_dependent) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Scenario 4: Parallel revolute axes (geometric degeneracy)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParallelRevoluteAxes:
    """Demo scenario 4: parallel revolute axes create geometric degeneracies."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        # Two revolute joints in series with identical (parallel) z axes:
        # combinatorially fine, but geometrically degenerate.
        bodies = [
            RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
            RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
            RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
        ]
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([3.0, 0.0, 0.0]),
                anchor_b=np.array([3.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None:
        """Parallel axes produce at least one geometric degeneracy."""
        assert result.geometric_degeneracies > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoJoints:
    """Assembly with bodies but no joints."""

    def test_all_dof_free(self) -> None:
        report = analyze_assembly(_two_bodies(), [], ground_body=0)
        # Body 0 is grounded; body 1 keeps all 6 of its DOF.
        assert report.jacobian_internal_dof > 0
        assert not report.is_rigid

    def test_ungrounded(self) -> None:
        report = analyze_assembly(_two_bodies(), [])
        assert report.combinatorial_classification == "underconstrained"
|
||||||
|
|
||||||
|
|
||||||
|
class TestReturnType:
    """Verify the return object is a proper ConstraintAnalysis."""

    def test_instance(self) -> None:
        fixed_joint = [Joint(0, 0, 1, JointType.FIXED)]
        report = analyze_assembly(_two_bodies(), fixed_joint)
        assert isinstance(report, ConstraintAnalysis)

    def test_per_edge_results_populated(self) -> None:
        hinge = [Joint(0, 0, 1, JointType.REVOLUTE)]
        report = analyze_assembly(_two_bodies(), hinge)
        # A revolute joint removes 5 DOF -> 5 per-edge entries.
        assert len(report.per_edge_results) == 5
|
||||||
337
tests/datagen/test_dataset.py
Normal file
337
tests/datagen/test_dataset.py
Normal file
@@ -0,0 +1,337 @@
|
|||||||
|
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from solver.datagen.dataset import (
|
||||||
|
DatasetConfig,
|
||||||
|
DatasetGenerator,
|
||||||
|
_derive_shard_seed,
|
||||||
|
_parse_scalar,
|
||||||
|
parse_simple_yaml,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DatasetConfig
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        cfg = DatasetConfig()
        assert cfg.num_assemblies == 100_000
        assert cfg.seed == 42
        assert cfg.shard_size == 1000
        assert cfg.num_workers == 4

    def test_from_dict_flat(self) -> None:
        # Flat scalar keys map directly onto config attributes.
        d: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.num_assemblies == 500
        assert cfg.seed == 123

    def test_from_dict_nested_body_count(self) -> None:
        # Nested {"min": ..., "max": ...} form expands to _min/_max attrs.
        d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.body_count_min == 3
        assert cfg.body_count_max == 20

    def test_from_dict_flat_body_count(self) -> None:
        # The already-flattened spelling is accepted too.
        d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.body_count_min == 5
        assert cfg.body_count_max == 30

    def test_from_dict_complexity_distribution(self) -> None:
        d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        d: dict[str, Any] = {"templates": ["chain", "tree"]}
        cfg = DatasetConfig.from_dict(d)
        assert cfg.templates == ["chain", "tree"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Minimal YAML parser
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseScalar:
    """_parse_scalar handles different value types."""

    def test_int(self) -> None:
        assert _parse_scalar("42") == 42

    def test_float(self) -> None:
        assert _parse_scalar("3.14") == 3.14

    def test_bool_true(self) -> None:
        assert _parse_scalar("true") is True

    def test_bool_false(self) -> None:
        assert _parse_scalar("false") is False

    def test_string(self) -> None:
        # Anything that is not an int/float/bool falls through as a string.
        assert _parse_scalar("hello") == "hello"

    def test_inline_comment(self) -> None:
        # Trailing "# ..." comments are stripped before value parsing.
        assert _parse_scalar("0.4 # some comment") == 0.4
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["name"] == "test"
        assert result["num"] == 42
        assert result["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        # A two-space-indented mapping becomes a nested dict.
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("body_count:\n  min: 2\n  max: 50\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        # "- item" entries under a key become a Python list.
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("templates:\n  - chain\n  - tree\n  - loop\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("dist:\n  simple: 0.4  # comment\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config."""
        # NOTE(review): relies on running pytest from the repo root so the
        # relative path resolves — confirm CI working directory.
        result = parse_simple_yaml("configs/dataset/synthetic.yaml")
        assert result["name"] == "synthetic"
        assert result["num_assemblies"] == 100000
        assert isinstance(result["complexity_distribution"], dict)
        assert isinstance(result["templates"], list)
        assert result["shard_size"] == 1000
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Shard seed derivation
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        # Same (global seed, shard) pair must always yield the same seed.
        assert _derive_shard_seed(42, 0) == _derive_shard_seed(42, 0)

    def test_different_shards(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DatasetGenerator — small end-to-end tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestDatasetGenerator:
    """End-to-end tests with small datasets.

    Every test builds a small single-worker, seed-42 config through
    ``_make_cfg`` (the five-field DatasetConfig boilerplate was previously
    duplicated verbatim in each test), runs the generator, and inspects
    the artifacts written to disk (shards, index.json, stats.json).
    """

    @staticmethod
    def _make_cfg(out_dir: Path, *, num_assemblies: int, shard_size: int) -> DatasetConfig:
        """Build a seed-42, single-worker config writing to *out_dir*."""
        return DatasetConfig(
            num_assemblies=num_assemblies,
            output_dir=str(out_dir),
            shard_size=shard_size,
            seed=42,
            num_workers=1,
        )

    def test_small_generation(self, tmp_path: Path) -> None:
        """Generate 10 examples in a single shard."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=10, shard_size=10)
        DatasetGenerator(cfg).run()

        shards_dir = tmp_path / "output" / "shards"
        assert shards_dir.exists()
        assert len(sorted(shards_dir.glob("shard_*"))) == 1

        index_file = tmp_path / "output" / "index.json"
        assert index_file.exists()
        index = json.loads(index_file.read_text())
        assert index["total_assemblies"] == 10

        assert (tmp_path / "output" / "stats.json").exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """Generate 20 examples across 2 shards."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=20, shard_size=10)
        DatasetGenerator(cfg).run()

        shard_files = sorted((tmp_path / "output" / "shards").glob("shard_*"))
        assert len(shard_files) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """Resume skips already-completed shards."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=20, shard_size=10)
        DatasetGenerator(cfg).run()

        # Record shard modification times so regeneration is detectable.
        shards_dir = tmp_path / "output" / "shards"
        mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}

        # Remove stats (simulate an incomplete run) and re-run.
        (tmp_path / "output" / "stats.json").unlink()
        DatasetGenerator(cfg).run()

        # Shards must NOT have been regenerated...
        for p in shards_dir.glob("shard_*"):
            assert p.stat().st_mtime == mtimes[p.name]
        # ...but stats must be rebuilt.
        assert (tmp_path / "output" / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """Checkpoint file is cleaned up after completion."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=5, shard_size=5)
        DatasetGenerator(cfg).run()
        assert not (tmp_path / "output" / ".checkpoint.json").exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json has expected top-level keys."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=10, shard_size=10)
        DatasetGenerator(cfg).run()

        stats = json.loads((tmp_path / "output" / "stats.json").read_text())
        assert stats["total_examples"] == 10
        for key in (
            "classification_distribution",
            "body_count_histogram",
            "joint_type_distribution",
            "dof_statistics",
            "geometric_degeneracy",
            "rigidity",
        ):
            assert key in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json has expected format."""
        cfg = self._make_cfg(tmp_path / "output", num_assemblies=15, shard_size=10)
        DatasetGenerator(cfg).run()

        index = json.loads((tmp_path / "output" / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        assert index["total_shards"] == 2
        assert "shards" in index
        for info in index["shards"].values():
            assert "start_id" in info
            assert "count" in info

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Same seed produces same results."""
        for run_dir in ("run1", "run2"):
            cfg = self._make_cfg(tmp_path / run_dir, num_assemblies=5, shard_size=5)
            DatasetGenerator(cfg).run()

        s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
        s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert s1["total_examples"] == s2["total_examples"]
        assert s1["classification_distribution"] == s2["classification_distribution"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# CLI integration test
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestCLI:
    """Run the script via subprocess."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        """The argparse entry point runs end-to-end and writes index/stats.

        The original test hard-coded ``cwd="/home/developer"`` and the same
        path in PYTHONPATH, which only works on one machine. Use the current
        working directory instead (assumes pytest is invoked from the repo
        root, which is also what the relative script path requires).
        """
        repo_root = os.getcwd()
        result = subprocess.run(
            [
                sys.executable,
                "scripts/generate_synthetic.py",
                "--num-assemblies",
                "5",
                "--output-dir",
                str(tmp_path / "cli_out"),
                "--shard-size",
                "5",
                "--num-workers",
                "1",
                "--seed",
                "42",
            ],
            capture_output=True,
            text=True,
            cwd=repo_root,
            timeout=120,
            env={**os.environ, "PYTHONPATH": repo_root},
        )
        # Surface the child's output on failure for easier debugging.
        assert result.returncode == 0, (
            f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )
        assert (tmp_path / "cli_out" / "index.json").exists()
        assert (tmp_path / "cli_out" / "stats.json").exists()
|
||||||
682
tests/datagen/test_generator.py
Normal file
682
tests/datagen/test_generator.py
Normal file
@@ -0,0 +1,682 @@
|
|||||||
|
"""Tests for solver.datagen.generator -- synthetic assembly generation."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||||
|
from solver.datagen.types import JointType
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Original generators (chain / rigid / overconstrained)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestChainAssembly:
    """generate_chain_assembly produces valid underconstrained chains."""

    def test_returns_three_tuple(self) -> None:
        chain_bodies, chain_joints, _analysis = SyntheticAssemblyGenerator(
            seed=0
        ).generate_chain_assembly(4)
        # n bodies linked in a line need n - 1 joints.
        assert (len(chain_bodies), len(chain_joints)) == (4, 3)

    def test_chain_underconstrained(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=0).generate_chain_assembly(4)
        assert result.combinatorial_classification == "underconstrained"

    def test_chain_body_ids(self) -> None:
        chain_bodies, _, _ = SyntheticAssemblyGenerator(seed=0).generate_chain_assembly(5)
        assert [body.body_id for body in chain_bodies] == list(range(5))

    def test_chain_joint_connectivity(self) -> None:
        _, chain_joints, _ = SyntheticAssemblyGenerator(seed=0).generate_chain_assembly(4)
        # Joint i links body i to body i + 1 along the chain.
        for idx, joint in enumerate(chain_joints):
            assert (joint.body_a, joint.body_b) == (idx, idx + 1)

    def test_chain_custom_joint_type(self) -> None:
        _, chain_joints, _ = SyntheticAssemblyGenerator(seed=0).generate_chain_assembly(
            3,
            joint_type=JointType.BALL,
        )
        assert all(joint.joint_type is JointType.BALL for joint in chain_joints)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRigidAssembly:
    """generate_rigid_assembly produces rigid assemblies."""

    def test_rigid(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=42).generate_rigid_assembly(4)
        assert result.is_rigid

    def test_spanning_tree_structure(self) -> None:
        """n bodies should have at least n-1 joints (spanning tree)."""
        rigid_bodies, rigid_joints, _ = SyntheticAssemblyGenerator(
            seed=42
        ).generate_rigid_assembly(5)
        assert len(rigid_joints) >= len(rigid_bodies) - 1

    def test_deterministic(self) -> None:
        """Same seed produces same results."""
        _, joints_a, analysis_a = SyntheticAssemblyGenerator(seed=99).generate_rigid_assembly(4)
        _, joints_b, analysis_b = SyntheticAssemblyGenerator(seed=99).generate_rigid_assembly(4)
        # Identical seeds must agree on rank and joint count.
        assert analysis_a.jacobian_rank == analysis_b.jacobian_rank
        assert len(joints_a) == len(joints_b)
|
||||||
|
|
||||||
|
|
||||||
|
class TestOverconstrainedAssembly:
    """generate_overconstrained_assembly adds redundant constraints."""

    def test_has_redundant(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=42).generate_overconstrained_assembly(
            4,
            extra_joints=2,
        )
        assert result.combinatorial_redundant > 0

    def test_extra_joints_added(self) -> None:
        # Two fresh generators with the same seed, so the rigid base layouts
        # are comparable.
        _, base_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_rigid_assembly(4)
        _, over_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_overconstrained_assembly(
            4,
            extra_joints=3,
        )
        # The overconstrained variant carries the base joints plus extras.
        assert len(over_joints) > len(base_joints)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# New topology generators
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTreeAssembly:
    """generate_tree_assembly produces tree-structured assemblies."""

    def test_body_and_joint_counts(self) -> None:
        tree_bodies, tree_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(6)
        # A tree on n nodes always has n - 1 edges.
        assert (len(tree_bodies), len(tree_joints)) == (6, 5)

    def test_underconstrained(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(6)
        assert result.combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        tree_bodies, tree_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(
            10,
            branching_factor=2,
        )
        assert len(tree_bodies) == 10
        assert len(tree_joints) == 9

    def test_mixed_joint_types(self) -> None:
        palette = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        _, tree_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(
            10, joint_types=palette
        )
        # With 9 joints drawn from 3 types, at least 2 distinct types are
        # overwhelmingly likely.
        assert len({joint.joint_type for joint in tree_joints}) >= 2

    def test_single_joint_type(self) -> None:
        _, tree_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(
            5,
            joint_types=JointType.BALL,
        )
        assert all(joint.joint_type is JointType.BALL for joint in tree_joints)

    def test_sequential_body_ids(self) -> None:
        tree_bodies, _, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(7)
        assert [body.body_id for body in tree_bodies] == list(range(7))
|
||||||
|
|
||||||
|
|
||||||
|
class TestLoopAssembly:
    """generate_loop_assembly produces closed-loop assemblies."""

    def test_body_and_joint_counts(self) -> None:
        loop_bodies, loop_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(5)
        # Closing the loop means n joints for n bodies.
        assert (len(loop_bodies), len(loop_joints)) == (5, 5)

    def test_has_redundancy(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(5)
        assert result.combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        _, loop_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(4)
        observed = {(joint.body_a, joint.body_b) for joint in loop_joints}
        # Consecutive pairs plus the closing (3, 0) wrap-around edge.
        for edge in [(0, 1), (1, 2), (2, 3), (3, 0)]:
            assert edge in observed

    def test_minimum_bodies_error(self) -> None:
        with pytest.raises(ValueError, match="at least 3"):
            SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        _, loop_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(
            8, joint_types=[JointType.REVOLUTE, JointType.FIXED]
        )
        assert len({joint.joint_type for joint in loop_joints}) >= 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestStarAssembly:
    """generate_star_assembly produces hub-and-spoke assemblies."""

    def test_body_and_joint_counts(self) -> None:
        star_bodies, star_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(6)
        # One spoke per non-hub body: n - 1 joints.
        assert (len(star_bodies), len(star_joints)) == (6, 5)

    def test_all_joints_connect_to_hub(self) -> None:
        _, star_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(6)
        # Body 0 is the hub; every joint must touch it.
        assert all(0 in (joint.body_a, joint.body_b) for joint in star_joints)

    def test_underconstrained(self) -> None:
        _, _, result = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(5)
        assert result.combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        with pytest.raises(ValueError, match="at least 2"):
            SyntheticAssemblyGenerator(seed=42).generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        star_bodies, _, _ = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(4)
        np.testing.assert_array_equal(star_bodies[0].position, np.zeros(3))
|
||||||
|
|
||||||
|
|
||||||
|
class TestMixedAssembly:
    """generate_mixed_assembly produces tree+loop hybrid assemblies."""

    def test_more_joints_than_tree(self) -> None:
        mixed_bodies, mixed_joints, _ = SyntheticAssemblyGenerator(
            seed=42
        ).generate_mixed_assembly(8, edge_density=0.3)
        # Positive density adds chord edges beyond the spanning tree.
        assert len(mixed_joints) > len(mixed_bodies) - 1

    def test_density_zero_is_tree(self) -> None:
        _bodies, mixed_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_mixed_assembly(
            5, edge_density=0.0
        )
        assert len(mixed_joints) == 4  # spanning tree only

    def test_density_validation(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        # Out-of-range densities, both above and below [0, 1], must raise.
        for bad_density in (1.5, -0.1):
            with pytest.raises(ValueError, match="must be in"):
                maker.generate_mixed_assembly(5, edge_density=bad_density)

    def test_no_duplicate_edges(self) -> None:
        _, mixed_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_mixed_assembly(
            6, edge_density=0.5
        )
        # Undirected comparison: (a, b) and (b, a) are the same edge.
        undirected = [frozenset([joint.body_a, joint.body_b]) for joint in mixed_joints]
        assert len(undirected) == len(set(undirected))

    def test_high_density(self) -> None:
        _bodies, mixed_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_mixed_assembly(
            5, edge_density=1.0
        )
        # Fully connected graph on 5 nodes: 5 * (5 - 1) / 2 == 10 edges.
        assert len(mixed_joints) == 10
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Axis sampling strategies
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAxisStrategy:
    """Axis sampling strategies produce valid unit vectors."""

    def test_cardinal_axis_from_six(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        observed = {tuple(maker._cardinal_axis()) for _ in range(200)}
        # 200 draws should cover all six +/- unit directions and nothing else.
        assert observed == {
            (1, 0, 0),
            (-1, 0, 0),
            (0, 1, 0),
            (0, -1, 0),
            (0, 0, 1),
            (0, 0, -1),
        }

    def test_random_axis_unit_norm(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        for _ in range(50):
            sampled = maker._sample_axis("random")
            assert abs(np.linalg.norm(sampled) - 1.0) < 1e-10

    def test_near_parallel_close_to_base(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        reference = np.array([0.0, 0.0, 1.0])
        for _ in range(50):
            sampled = maker._near_parallel_axis(reference)
            # Unit length and a small angular deviation from the reference.
            assert abs(np.linalg.norm(sampled) - 1.0) < 1e-10
            assert np.dot(sampled, reference) > 0.95

    def test_sample_axis_cardinal(self) -> None:
        sampled = SyntheticAssemblyGenerator(seed=0)._sample_axis("cardinal")
        candidates = [
            np.array(direction, dtype=float)
            for direction in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
        ]
        assert any(np.allclose(sampled, candidate) for candidate in candidates)

    def test_sample_axis_near_parallel(self) -> None:
        sampled = SyntheticAssemblyGenerator(seed=0)._sample_axis("near_parallel")
        # near_parallel samples around the +z direction.
        assert np.dot(sampled, np.array([0.0, 0.0, 1.0])) > 0.95
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Geometric diversity: orientations
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRandomOrientations:
    """Bodies should have non-identity orientations."""

    def test_bodies_have_orientations(self) -> None:
        tree_bodies, _, _ = SyntheticAssemblyGenerator(seed=42).generate_tree_assembly(5)
        rotated = [b for b in tree_bodies if not np.allclose(b.orientation, np.eye(3))]
        # At least 3 of 5 bodies carry a non-identity rotation.
        assert len(rotated) >= 3

    def test_orientations_are_valid_rotations(self) -> None:
        star_bodies, _, _ = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(6)
        for body in star_bodies:
            rot = body.orientation
            # Orthonormal: R^T R == I.
            np.testing.assert_allclose(rot.T @ rot, np.eye(3), atol=1e-10)
            # Proper rotation (no reflection): det(R) == +1.
            assert abs(np.linalg.det(rot) - 1.0) < 1e-10

    def test_all_generators_set_orientations(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        # Chain
        built_bodies, _, _ = maker.generate_chain_assembly(3)
        assert not np.allclose(built_bodies[1].orientation, np.eye(3))
        # Loop
        built_bodies, _, _ = maker.generate_loop_assembly(4)
        assert not np.allclose(built_bodies[1].orientation, np.eye(3))
        # Mixed
        built_bodies, _, _ = maker.generate_mixed_assembly(4)
        assert not np.allclose(built_bodies[1].orientation, np.eye(3))
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Geometric diversity: grounded parameter
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestGroundedParameter:
    """Grounded parameter controls ground_body in analysis."""

    def test_chain_grounded_default(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(4)
        assert analysis.combinatorial_dof >= 0

    def test_chain_floating(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(
            4,
            grounded=False,
        )
        # Floating: the 6 trivial rigid-body DOF are not removed by a ground body.
        assert analysis.combinatorial_dof >= 6

    def test_floating_vs_grounded_dof_difference(self) -> None:
        gen1 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
        gen2 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
        # Floating should have higher DOF due to missing ground constraint.
        assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof

    @pytest.mark.parametrize(
        "gen_method",
        [
            "generate_chain_assembly",
            "generate_rigid_assembly",
            "generate_tree_assembly",
            "generate_loop_assembly",
            "generate_star_assembly",
            "generate_mixed_assembly",
        ],
    )
    def test_all_generators_accept_grounded(
        self,
        gen_method: str,
    ) -> None:
        """Every generator accepts grounded= without raising."""
        gen = SyntheticAssemblyGenerator(seed=42)
        # The original if/else executed the identical call in both branches,
        # so the branch was dead weight — a single call suffices.
        getattr(gen, gen_method)(4, grounded=False)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Geometric diversity: parallel axis injection
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestParallelAxisInjection:
    """parallel_axis_prob causes shared axis direction."""

    @staticmethod
    def _assert_near_parallel(joints) -> None:
        """All joint axes must be near-parallel (|cos| > 0.9) to the first."""
        reference = joints[0].axis
        for joint in joints[1:]:
            assert abs(np.dot(joint.axis, reference)) > 0.9

    def test_parallel_axes_similar(self) -> None:
        _, chain_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_chain_assembly(
            6, parallel_axis_prob=1.0
        )
        self._assert_near_parallel(chain_joints)

    def test_zero_prob_no_forced_parallel(self) -> None:
        _, chain_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_chain_assembly(
            6, parallel_axis_prob=0.0
        )
        reference = chain_joints[0].axis
        alignments = [abs(np.dot(joint.axis, reference)) for joint in chain_joints[1:]]
        # Five independently random axes are essentially never all parallel.
        assert min(alignments) < 0.95

    def test_parallel_on_loop(self) -> None:
        """Parallel axes on a loop assembly."""
        _, loop_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_loop_assembly(
            5, parallel_axis_prob=1.0
        )
        self._assert_near_parallel(loop_joints)

    def test_parallel_on_star(self) -> None:
        """Parallel axes on a star assembly."""
        _, star_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_star_assembly(
            5, parallel_axis_prob=1.0
        )
        self._assert_near_parallel(star_joints)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Complexity tiers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestComplexityTiers:
    """Complexity tier parameter on batch generation."""

    @staticmethod
    def _assert_in_tier(batch, tier: str) -> None:
        """Every example's body count falls in the tier's [lo, hi) range."""
        lo, hi = COMPLEXITY_RANGES[tier]
        for example in batch:
            assert lo <= example["n_bodies"] < hi

    def test_simple_range(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            20, complexity_tier="simple"
        )
        self._assert_in_tier(batch, "simple")

    def test_medium_range(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            20, complexity_tier="medium"
        )
        self._assert_in_tier(batch, "medium")

    def test_complex_range(self) -> None:
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            3, complexity_tier="complex"
        )
        self._assert_in_tier(batch, "complex")

    def test_tier_overrides_range(self) -> None:
        # An explicit n_bodies_range loses to the named tier.
        batch = SyntheticAssemblyGenerator(seed=42).generate_training_batch(
            10,
            n_bodies_range=(2, 3),
            complexity_tier="medium",
        )
        self._assert_in_tier(batch, "medium")
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Training batch
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestTrainingBatch:
|
||||||
|
"""generate_training_batch produces well-structured examples."""
|
||||||
|
|
||||||
|
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||||
|
"example_id",
|
||||||
|
"generator_type",
|
||||||
|
"grounded",
|
||||||
|
"n_bodies",
|
||||||
|
"n_joints",
|
||||||
|
"body_positions",
|
||||||
|
"body_orientations",
|
||||||
|
"joints",
|
||||||
|
"joint_labels",
|
||||||
|
"labels",
|
||||||
|
"assembly_classification",
|
||||||
|
"is_rigid",
|
||||||
|
"is_minimally_rigid",
|
||||||
|
"internal_dof",
|
||||||
|
"geometric_degeneracies",
|
||||||
|
}
|
||||||
|
|
||||||
|
VALID_GEN_TYPES: ClassVar[set[str]] = {
|
||||||
|
"chain",
|
||||||
|
"rigid",
|
||||||
|
"overconstrained",
|
||||||
|
"tree",
|
||||||
|
"loop",
|
||||||
|
"star",
|
||||||
|
"mixed",
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_batch_size(self) -> None:
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(20)
|
||||||
|
assert len(batch) == 20
|
||||||
|
|
||||||
|
def test_example_keys(self) -> None:
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(10)
|
||||||
|
for ex in batch:
|
||||||
|
assert set(ex.keys()) == self.EXPECTED_KEYS
|
||||||
|
|
||||||
|
def test_example_ids_sequential(self) -> None:
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(15)
|
||||||
|
assert [ex["example_id"] for ex in batch] == list(range(15))
|
||||||
|
|
||||||
|
def test_generator_type_valid(self) -> None:
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(50)
|
||||||
|
for ex in batch:
|
||||||
|
assert ex["generator_type"] in self.VALID_GEN_TYPES
|
||||||
|
|
||||||
|
def test_generator_type_diversity(self) -> None:
|
||||||
|
"""100-sample batch should use at least 5 of 7 generator types."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(100)
|
||||||
|
types = {ex["generator_type"] for ex in batch}
|
||||||
|
assert len(types) >= 5
|
||||||
|
|
||||||
|
def test_default_body_range(self) -> None:
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(30)
|
||||||
|
for ex in batch:
|
||||||
|
# default (3, 8), but loop/star may clamp
|
||||||
|
assert 2 <= ex["n_bodies"] <= 7
|
||||||
|
|
||||||
|
def test_joint_label_consistency(self) -> None:
|
||||||
|
"""independent + redundant == total for every joint."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(30)
|
||||||
|
for ex in batch:
|
||||||
|
for label in ex["joint_labels"].values():
|
||||||
|
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||||
|
assert total == label["total_constraints"]
|
||||||
|
|
||||||
|
def test_body_orientations_present(self) -> None:
|
||||||
|
"""Each example includes body_orientations as 3x3 lists."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(10)
|
||||||
|
for ex in batch:
|
||||||
|
orients = ex["body_orientations"]
|
||||||
|
assert len(orients) == ex["n_bodies"]
|
||||||
|
for o in orients:
|
||||||
|
assert len(o) == 3
|
||||||
|
assert len(o[0]) == 3
|
||||||
|
|
||||||
|
def test_labels_structure(self) -> None:
|
||||||
|
"""Each example has labels dict with expected sub-keys."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(10)
|
||||||
|
for ex in batch:
|
||||||
|
labels = ex["labels"]
|
||||||
|
assert "per_constraint" in labels
|
||||||
|
assert "per_joint" in labels
|
||||||
|
assert "per_body" in labels
|
||||||
|
assert "assembly" in labels
|
||||||
|
|
||||||
|
def test_grounded_field_present(self) -> None:
    """Every example carries a boolean ``grounded`` flag."""
    generator = SyntheticAssemblyGenerator(seed=42)
    for example in generator.generate_training_batch(10):
        assert isinstance(example["grounded"], bool)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Batch grounded ratio
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchGroundedRatio:
    """grounded_ratio controls the mix in batch generation."""

    def test_all_grounded(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(20, grounded_ratio=1.0)
        assert all(example["grounded"] for example in examples)

    def test_none_grounded(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(20, grounded_ratio=0.0)
        assert not any(example["grounded"] for example in examples)

    def test_mixed_ratio(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(100, grounded_ratio=0.5)
        n_grounded = sum(1 for example in examples if example["grounded"])
        # 100 draws at p=0.5: accept a generous 20..80 window.
        assert 20 < n_grounded < 80

    def test_batch_axis_strategy_cardinal(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(10, axis_strategy="cardinal")
        assert len(examples) == 10

    def test_batch_parallel_axis_prob(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(10, parallel_axis_prob=0.5)
        assert len(examples) == 10
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Seed reproducibility
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeedReproducibility:
    """Different seeds produce different results."""

    def test_different_seeds_differ(self) -> None:
        batch_a = SyntheticAssemblyGenerator(seed=1).generate_training_batch(
            batch_size=5,
            n_bodies_range=(3, 6),
        )
        batch_b = SyntheticAssemblyGenerator(seed=2).generate_training_batch(
            batch_size=5,
            n_bodies_range=(3, 6),
        )
        cls_a = [example["assembly_classification"] for example in batch_a]
        cls_b = [example["assembly_classification"] for example in batch_b]
        rigid_a = [example["is_rigid"] for example in batch_a]
        rigid_b = [example["is_rigid"] for example in batch_b]
        # At least one observable differs between the two seeds.
        assert cls_a != cls_b or rigid_a != rigid_b

    def test_same_seed_identical(self) -> None:
        bodies_a, joints_a, _ = SyntheticAssemblyGenerator(seed=123).generate_tree_assembly(5)
        bodies_b, joints_b, _ = SyntheticAssemblyGenerator(seed=123).generate_tree_assembly(5)
        for left, right in zip(bodies_a, bodies_b, strict=True):
            np.testing.assert_array_almost_equal(left.position, right.position)
        assert len(joints_a) == len(joints_b)
|
||||||
267
tests/datagen/test_jacobian.py
Normal file
267
tests/datagen/test_jacobian.py
Normal file
@@ -0,0 +1,267 @@
|
|||||||
|
"""Tests for solver.datagen.jacobian -- Jacobian rank verification."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from solver.datagen.jacobian import JacobianVerifier
|
||||||
|
from solver.datagen.types import Joint, JointType, RigidBody
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Fixtures
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _two_bodies() -> list[RigidBody]:
    """Two bodies 2 units apart along X, ids 0 and 1."""
    positions = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(pos)) for i, pos in enumerate(positions)]
|
||||||
|
|
||||||
|
|
||||||
|
def _three_bodies() -> list[RigidBody]:
    """Three bodies spaced 2 units apart along X, ids 0..2."""
    return [
        RigidBody(i, position=np.array([2.0 * i, 0.0, 0.0]))
        for i in range(3)
    ]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestJacobianShape:
    """Verify Jacobian matrix dimensions for each joint type."""

    @pytest.mark.parametrize(
        "joint_type,expected_rows",
        [
            (JointType.FIXED, 6),
            (JointType.REVOLUTE, 5),
            (JointType.CYLINDRICAL, 4),
            (JointType.SLIDER, 5),
            (JointType.BALL, 3),
            (JointType.PLANAR, 3),
            (JointType.SCREW, 5),
            (JointType.UNIVERSAL, 4),
            (JointType.PARALLEL, 3),
            (JointType.PERPENDICULAR, 1),
            (JointType.DISTANCE, 1),
        ],
    )
    def test_row_count(self, joint_type: JointType, expected_rows: int) -> None:
        verifier = JacobianVerifier(_two_bodies())
        rows_added = verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=joint_type,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        assert rows_added == expected_rows

        # Each of the 2 bodies contributes 6 columns.
        assert verifier.get_jacobian().shape == (expected_rows, 12)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNumericalRank:
    """Numerical rank checks for known configurations."""

    @staticmethod
    def _single_joint_rank(joint_type: JointType, axis: np.ndarray | None = None) -> int:
        """Rank of the Jacobian for one joint of *joint_type* between two bodies."""
        verifier = JacobianVerifier(_two_bodies())
        extra = {} if axis is None else {"axis": axis}
        verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=joint_type,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                **extra,
            )
        )
        return verifier.numerical_rank()

    def test_fixed_joint_rank_six(self) -> None:
        """Fixed joint between 2 bodies: rank = 6."""
        assert self._single_joint_rank(JointType.FIXED) == 6

    def test_revolute_joint_rank_five(self) -> None:
        """Revolute joint between 2 bodies: rank = 5."""
        z_axis = np.array([0.0, 0.0, 1.0])
        assert self._single_joint_rank(JointType.REVOLUTE, axis=z_axis) == 5

    def test_ball_joint_rank_three(self) -> None:
        """Ball joint between 2 bodies: rank = 3."""
        assert self._single_joint_rank(JointType.BALL) == 3

    def test_empty_jacobian_rank_zero(self) -> None:
        # No constraints added yet: rank must be zero.
        assert JacobianVerifier(_two_bodies()).numerical_rank() == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestParallelAxesDegeneracy:
    """Parallel revolute axes create geometric dependencies."""

    def _four_body_loop(self) -> list[RigidBody]:
        corners = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [2.0, 2.0, 0.0], [0.0, 2.0, 0.0])
        return [RigidBody(i, position=np.array(c)) for i, c in enumerate(corners)]

    def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]:
        pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])]
        joints = []
        for idx, (a, b, anchor) in enumerate(pairs):
            joints.append(
                Joint(
                    joint_id=idx,
                    body_a=a,
                    body_b=b,
                    joint_type=JointType.REVOLUTE,
                    anchor_a=np.array(anchor, dtype=float),
                    anchor_b=np.array(anchor, dtype=float),
                    axis=axes[idx],
                )
            )
        return joints

    def test_parallel_has_lower_rank(self) -> None:
        """4-body closed loop: all-parallel revolute axes produce lower
        Jacobian rank than mixed axes due to geometric dependency."""
        bodies = self._four_body_loop()

        def loop_rank(axes: list[np.ndarray]) -> int:
            verifier = JacobianVerifier(bodies)
            for joint in self._loop_joints(axes):
                verifier.add_joint_constraints(joint)
            return verifier.numerical_rank()

        # All axes parallel to Z.
        z_axis = np.array([0.0, 0.0, 1.0])
        rank_parallel = loop_rank([z_axis] * 4)

        # A mix of Z, Y and X axes.
        mixed_axes = [
            np.array([0.0, 0.0, 1.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 0.0, 0.0]),
        ]
        rank_mixed = loop_rank(mixed_axes)

        assert rank_parallel < rank_mixed
|
||||||
|
|
||||||
|
|
||||||
|
class TestFindDependencies:
    """Dependency detection."""

    @staticmethod
    def _fixed_joint(jid: int) -> Joint:
        """Fixed joint between bodies 0 and 1 with a shared anchor."""
        return Joint(
            joint_id=jid,
            body_a=0,
            body_b=1,
            joint_type=JointType.FIXED,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([1.0, 0.0, 0.0]),
        )

    def test_fixed_joint_no_dependencies(self) -> None:
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(self._fixed_joint(0))
        assert verifier.find_dependencies() == []

    def test_duplicate_fixed_has_dependencies(self) -> None:
        """Two fixed joints on same pair: second is fully dependent."""
        verifier = JacobianVerifier(_two_bodies())
        for jid in range(2):
            verifier.add_joint_constraints(self._fixed_joint(jid))
        # The duplicate joint repeats all 6 rows of the first.
        assert len(verifier.find_dependencies()) == 6

    def test_empty_no_dependencies(self) -> None:
        assert JacobianVerifier(_two_bodies()).find_dependencies() == []
|
||||||
|
|
||||||
|
|
||||||
|
class TestRowLabels:
    """Row label metadata."""

    def test_labels_match_rows(self) -> None:
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=7,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        # A revolute joint contributes 5 rows, all tagged with its id.
        assert len(verifier.row_labels) == 5
        assert all(label["joint_id"] == 7 for label in verifier.row_labels)
|
||||||
|
|
||||||
|
|
||||||
|
class TestPerpendicularPair:
    """Internal _perpendicular_pair utility."""

    @pytest.mark.parametrize(
        "axis",
        [
            np.array([1.0, 0.0, 0.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 1.0, 1.0]) / np.sqrt(3),
        ],
    )
    def test_orthonormal(self, axis: np.ndarray) -> None:
        verifier = JacobianVerifier(_two_bodies())
        t1, t2 = verifier._perpendicular_pair(axis)

        # Both tangents have unit length...
        for tangent in (t1, t2):
            np.testing.assert_allclose(np.linalg.norm(tangent), 1.0, atol=1e-12)

        # ...and {axis, t1, t2} are mutually perpendicular.
        for u, w in ((axis, t1), (axis, t2), (t1, t2)):
            np.testing.assert_allclose(np.dot(u, w), 0.0, atol=1e-12)
|
||||||
346
tests/datagen/test_labeling.py
Normal file
346
tests/datagen/test_labeling.py
Normal file
@@ -0,0 +1,346 @@
|
|||||||
|
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from solver.datagen.labeling import (
|
||||||
|
label_assembly,
|
||||||
|
)
|
||||||
|
from solver.datagen.types import Joint, JointType, RigidBody
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
    """Build one RigidBody per position tuple, ids assigned in argument order."""
    bodies = []
    for body_id, pos in enumerate(positions):
        bodies.append(RigidBody(body_id=body_id, position=np.array(pos)))
    return bodies
|
||||||
|
|
||||||
|
|
||||||
|
def _make_joint(
    jid: int,
    a: int,
    b: int,
    jtype: JointType,
    axis: tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Joint:
    """Joint of type *jtype* between bodies *a* and *b*.

    Anchors are both at the origin; the axis defaults to +Z.
    """
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=jtype,
        anchor_a=np.zeros(3),
        anchor_b=np.zeros(3),
        axis=np.array(axis),
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-constraint labels
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestConstraintLabels:
    """Per-constraint labels combine pebble game and Jacobian results."""

    @staticmethod
    def _assert_all_independent(labels, expected_count: int) -> None:
        """Every per-constraint label is independent by both criteria."""
        assert len(labels.per_constraint) == expected_count
        for constraint in labels.per_constraint:
            assert constraint.pebble_independent is True
            assert constraint.jacobian_independent is True

    def test_fixed_joint_all_independent(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        labels = label_assembly(bodies, joints, ground_body=0)
        self._assert_all_independent(labels, 6)

    def test_revolute_joint_all_independent(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        labels = label_assembly(bodies, joints, ground_body=0)
        self._assert_all_independent(labels, 5)

    def test_chain_constraint_count(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        chain = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, chain, ground_body=0)
        assert len(labels.per_constraint) == 10  # 5 + 5

    def test_constraint_joint_ids(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        from_j0 = sum(1 for c in labels.per_constraint if c.joint_id == 0)
        from_j1 = sum(1 for c in labels.per_constraint if c.joint_id == 1)
        assert from_j0 == 5  # revolute
        assert from_j1 == 3  # ball

    def test_overconstrained_has_pebble_redundant(self) -> None:
        """Triangle with revolute joints: some constraints redundant."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        triangle = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, triangle, ground_body=0)
        assert any(not c.pebble_independent for c in labels.per_constraint)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-joint labels
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestJointLabels:
    """Per-joint aggregated labels."""

    def test_fixed_joint_counts(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.FIXED)], ground_body=0)
        assert len(labels.per_joint) == 1
        joint_label = labels.per_joint[0]
        assert joint_label.joint_id == 0
        assert joint_label.independent_count == 6
        assert joint_label.redundant_count == 0
        assert joint_label.total == 6

    def test_overconstrained_has_redundant_joints(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        triangle = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, triangle, ground_body=0)
        assert sum(jl.redundant_count for jl in labels.per_joint) > 0

    def test_joint_total_equals_dof(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.BALL)], ground_body=0)
        # A ball joint contributes 3 scalar constraints.
        assert labels.per_joint[0].total == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Per-body DOF labels
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestBodyDofLabels:
    """Per-body DOF signatures from nullspace projection."""

    def test_fixed_joint_grounded_both_zero(self) -> None:
        """Two bodies + fixed joint + grounded: both fully constrained."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.FIXED)], ground_body=0)
        for body_label in labels.per_body:
            assert body_label.translational_dof == 0
            assert body_label.rotational_dof == 0

    def test_revolute_has_rotational_dof(self) -> None:
        """Two bodies + revolute + grounded: body 1 has rotational DOF."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        body_one = next(b for b in labels.per_body if b.body_id == 1)
        # A revolute permits at least one rotation about its axis.
        assert body_one.rotational_dof >= 1

    def test_dof_bounds(self) -> None:
        """All DOF values should be in [0, 3]."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        chain = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, chain, ground_body=0)
        for body_label in labels.per_body:
            assert 0 <= body_label.translational_dof <= 3
            assert 0 <= body_label.rotational_dof <= 3

    def test_floating_more_dof_than_grounded(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]

        def total_dof(ground: int | None) -> int:
            labels = label_assembly(bodies, joints, ground_body=ground)
            return sum(b.translational_dof + b.rotational_dof for b in labels.per_body)

        assert total_dof(None) > total_dof(0)

    def test_grounded_body_zero_dof(self) -> None:
        """The grounded body should have 0 DOF."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        ground = next(b for b in labels.per_body if b.body_id == 0)
        assert ground.translational_dof == 0
        assert ground.rotational_dof == 0

    def test_body_count_matches(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_body) == 3
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Assembly label
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyLabel:
    """Assembly-wide summary labels."""

    def test_underconstrained_chain(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        chain = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        summary = label_assembly(bodies, chain, ground_body=0).assembly
        assert summary.classification == "underconstrained"
        assert summary.is_rigid is False

    def test_well_constrained(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.classification == "well-constrained"
        assert summary.is_rigid is True
        assert summary.is_minimally_rigid is True

    def test_overconstrained(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        triangle = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        summary = label_assembly(bodies, triangle, ground_body=0).assembly
        assert summary.redundant_count > 0

    def test_has_degeneracy_with_parallel_axes(self) -> None:
        """Parallel revolute axes in a loop create geometric degeneracy."""
        z_axis = (0.0, 0.0, 1.0)
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
        loop = [
            _make_joint(i, a, b, JointType.REVOLUTE, axis=z_axis)
            for i, (a, b) in enumerate([(0, 1), (1, 2), (2, 3), (3, 0)])
        ]
        summary = label_assembly(bodies, loop, ground_body=0).assembly
        assert summary.has_degeneracy is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Serialization
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestToDict:
    """to_dict produces JSON-serializable output."""

    @staticmethod
    def _label_dict() -> dict:
        """Labels for a minimal two-body revolute assembly, as a dict."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        return label_assembly(bodies, joints, ground_body=0).to_dict()

    def test_top_level_keys(self) -> None:
        expected = {"per_constraint", "per_joint", "per_body", "assembly"}
        assert set(self._label_dict().keys()) == expected

    def test_per_constraint_keys(self) -> None:
        expected = {
            "joint_id",
            "constraint_idx",
            "pebble_independent",
            "jacobian_independent",
        }
        for item in self._label_dict()["per_constraint"]:
            assert set(item.keys()) == expected

    def test_assembly_keys(self) -> None:
        expected = {
            "classification",
            "total_dof",
            "redundant_count",
            "is_rigid",
            "is_minimally_rigid",
            "has_degeneracy",
        }
        assert set(self._label_dict()["assembly"].keys()) == expected

    def test_json_serializable(self) -> None:
        # json.dumps raises TypeError on anything non-serializable.
        assert isinstance(json.dumps(self._label_dict()), str)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Edge cases
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestLabelAssemblyEdgeCases:
    """Edge cases for label_assembly."""

    def test_no_joints(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [], ground_body=0)
        assert len(labels.per_constraint) == 0
        assert len(labels.per_joint) == 0
        assert labels.assembly.classification == "underconstrained"
        # The non-ground body keeps all six rigid-body freedoms.
        free_body = next(b for b in labels.per_body if b.body_id == 1)
        assert free_body.translational_dof == 3
        assert free_body.rotational_dof == 3

    def test_no_joints_floating(self) -> None:
        labels = label_assembly(_make_bodies((0, 0, 0)), [], ground_body=None)
        assert len(labels.per_body) == 1
        only_body = labels.per_body[0]
        assert only_body.translational_dof == 3
        assert only_body.rotational_dof == 3

    def test_analysis_embedded(self) -> None:
        """AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        labels = label_assembly(bodies, joints, ground_body=0)
        for attr in ("combinatorial_classification", "jacobian_rank", "is_rigid"):
            assert hasattr(labels.analysis, attr)
|
||||||
206
tests/datagen/test_pebble_game.py
Normal file
206
tests/datagen/test_pebble_game.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from solver.datagen.pebble_game import PebbleGame3D
|
||||||
|
from solver.datagen.types import Joint, JointType
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint:
    """Shorthand for a revolute joint between bodies *a* and *b*.

    The axis defaults to +Z when not supplied.
    """
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.REVOLUTE,
        axis=np.array([0.0, 0.0, 1.0]) if axis is None else axis,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _fixed(jid: int, a: int, b: int) -> Joint:
    """Shorthand for a fixed joint between bodies *a* and *b*."""
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.FIXED,
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddBody:
    """Body registration basics."""

    def test_single_body_six_pebbles(self) -> None:
        game = PebbleGame3D()
        game.add_body(0)
        assert game.state.free_pebbles[0] == 6

    def test_duplicate_body_no_op(self) -> None:
        game = PebbleGame3D()
        game.add_body(0)
        game.add_body(0)
        # Re-adding must not reset or double the pebble count.
        assert game.state.free_pebbles[0] == 6

    def test_multiple_bodies(self) -> None:
        game = PebbleGame3D()
        for body_id in range(5):
            game.add_body(body_id)
        assert game.get_dof() == 30  # 5 bodies * 6 DOF each
|
||||||
|
|
||||||
|
|
||||||
|
class TestAddJoint:
    """Joint insertion and DOF accounting."""

    def test_revolute_removes_five_dof(self) -> None:
        game = PebbleGame3D()
        results = game.add_joint(_revolute(0, 0, 1))
        assert len(results) == 5  # 5 scalar constraints
        assert all(entry["independent"] for entry in results)
        # 2 bodies * 6 = 12 DOF, minus 5 independent constraints = 7.
        assert game.get_dof() == 7

    def test_fixed_removes_six_dof(self) -> None:
        game = PebbleGame3D()
        results = game.add_joint(_fixed(0, 0, 1))
        assert len(results) == 6
        assert all(entry["independent"] for entry in results)
        assert game.get_dof() == 6

    def test_ball_removes_three_dof(self) -> None:
        game = PebbleGame3D()
        ball = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL)
        results = game.add_joint(ball)
        assert len(results) == 3
        assert all(entry["independent"] for entry in results)
        assert game.get_dof() == 9
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwoBodiesRevolute:
    """Two bodies connected by a revolute -- demo scenario 1."""

    @staticmethod
    def _game() -> PebbleGame3D:
        """Fresh pebble game with a single revolute between bodies 0 and 1."""
        game = PebbleGame3D()
        game.add_joint(_revolute(0, 0, 1))
        return game

    def test_internal_dof(self) -> None:
        # Total DOF = 7; subtracting the 6 rigid-body modes leaves 1.
        assert self._game().get_internal_dof() == 1

    def test_not_rigid(self) -> None:
        assert not self._game().is_rigid()

    def test_classification(self) -> None:
        assert self._game().classify_assembly() == "underconstrained"
|
||||||
|
|
||||||
|
|
||||||
|
class TestTwoBodiesFixed:
    """Two bodies + fixed joint -- demo scenario 2."""

    @staticmethod
    def _build() -> PebbleGame3D:
        # A single fixed joint welds bodies 0 and 1 together.
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        return game

    def test_zero_internal_dof(self) -> None:
        assert self._build().get_internal_dof() == 0

    def test_rigid(self) -> None:
        assert self._build().is_rigid()

    def test_well_constrained(self) -> None:
        assert self._build().classify_assembly() == "well-constrained"
|
||||||
|
class TestTriangleRevolute:
    """Triangle of 3 bodies with revolute joints -- demo scenario 3."""

    @pytest.fixture()
    def pg(self) -> PebbleGame3D:
        # Close the loop 0-1-2-0 with one revolute per edge.
        game = PebbleGame3D()
        for joint_id, body_a, body_b in ((0, 0, 1), (1, 1, 2), (2, 2, 0)):
            game.add_joint(_revolute(joint_id, body_a, body_b))
        return game

    def test_has_redundant_edges(self, pg: PebbleGame3D) -> None:
        assert pg.get_redundant_count() > 0

    def test_classification_overconstrained(self, pg: PebbleGame3D) -> None:
        # 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed)
        assert pg.classify_assembly() in ("overconstrained", "mixed")

    def test_rigid(self, pg: PebbleGame3D) -> None:
        assert pg.is_rigid()
|
||||||
|
class TestChainNotRigid:
    """A serial chain of 4 bodies with revolute joints is never rigid."""

    @staticmethod
    def _chain() -> PebbleGame3D:
        # Bodies 0-1-2-3 linked by three revolutes in series.
        game = PebbleGame3D()
        for link in range(3):
            game.add_joint(_revolute(link, link, link + 1))
        return game

    def test_chain_underconstrained(self) -> None:
        game = self._chain()
        assert not game.is_rigid()
        assert game.classify_assembly() == "underconstrained"

    def test_chain_internal_dof(self) -> None:
        game = self._chain()
        # 4 bodies * 6 = 24 pebbles, minus 15 independent constraints
        # leaves 9 free; subtracting 6 rigid-body modes gives 3 internal DOF.
        assert game.get_internal_dof() == 3
|
||||||
|
class TestEdgeResults:
    """Shape and content of the per-constraint dicts add_joint returns."""

    def test_result_keys(self) -> None:
        game = PebbleGame3D()
        outcomes = game.add_joint(_revolute(0, 0, 1))
        expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"}
        # Every scalar-constraint record carries exactly the same key set.
        assert all(set(entry.keys()) == expected_keys for entry in outcomes)

    def test_edge_ids_sequential(self) -> None:
        game = PebbleGame3D()
        first = game.add_joint(_revolute(0, 0, 1))
        second = game.add_joint(_revolute(1, 1, 2))
        # Two revolutes produce 5 + 5 edges numbered 0..9 without gaps.
        all_ids = [entry["edge_id"] for entry in first + second]
        assert all_ids == list(range(10))

    def test_dof_remaining_monotonic(self) -> None:
        game = PebbleGame3D()
        outcomes = game.add_joint(_revolute(0, 0, 1))
        dofs = [entry["dof_remaining"] for entry in outcomes]
        # Non-increasing: each independent edge consumes one pebble.
        assert all(prev >= curr for prev, curr in zip(dofs, dofs[1:]))
||||||
|
class TestGroundedClassification:
    """classify_assembly with grounded=True."""

    def test_grounded_baseline_zero(self) -> None:
        """With grounded=True the baseline is 0 (not 6)."""
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        # Ungrounded: well-constrained (6 pebbles = baseline 6)
        assert game.classify_assembly(grounded=False) == "well-constrained"
        # Grounded: the 6 remaining pebbles on body 1 exceed baseline 0,
        # so the raw pebble game (without a virtual ground body) sees this
        # as underconstrained. The analysis function handles this properly
        # by adding a virtual ground body.
        assert game.classify_assembly(grounded=True) == "underconstrained"
# --- New file in this commit range: tests/datagen/test_types.py (163 lines added) ---
|
|||||||
|
"""Tests for solver.datagen.types -- shared data types."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import ClassVar
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from solver.datagen.types import (
|
||||||
|
ConstraintAnalysis,
|
||||||
|
Joint,
|
||||||
|
JointType,
|
||||||
|
PebbleState,
|
||||||
|
RigidBody,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJointType:
    """JointType enum membership and constrained-DOF values."""

    # Expected scalar-constraint count for every joint type.
    EXPECTED_DOF: ClassVar[dict[str, int]] = {
        "FIXED": 6,
        "REVOLUTE": 5,
        "CYLINDRICAL": 4,
        "SLIDER": 5,
        "BALL": 3,
        "PLANAR": 3,
        "SCREW": 5,
        "UNIVERSAL": 4,
        "PARALLEL": 3,
        "PERPENDICULAR": 1,
        "DISTANCE": 1,
    }

    def test_member_count(self) -> None:
        # The enum and the expectation table must stay in sync at 11 entries.
        assert len(JointType) == 11

    @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
    def test_dof_values(self, name: str, dof: int) -> None:
        assert JointType[name].dof == dof

    def test_access_by_name(self) -> None:
        # Name lookup returns the singleton member, not a copy.
        assert JointType["REVOLUTE"] is JointType.REVOLUTE

    def test_value_is_tuple(self) -> None:
        # Members are (index, dof) pairs; .dof exposes the second element.
        assert JointType.REVOLUTE.value == (1, 5)
        assert JointType.REVOLUTE.dof == 5
|
||||||
|
class TestRigidBody:
    """RigidBody dataclass defaults and construction."""

    def test_defaults(self) -> None:
        # Defaults: origin position, identity orientation, no anchors.
        body = RigidBody(body_id=0)
        np.testing.assert_array_equal(body.position, np.zeros(3))
        np.testing.assert_array_equal(body.orientation, np.eye(3))
        assert body.local_anchors == {}

    def test_custom_position(self) -> None:
        location = np.array([1.0, 2.0, 3.0])
        body = RigidBody(body_id=7, position=location)
        np.testing.assert_array_equal(body.position, location)
        assert body.body_id == 7

    def test_local_anchors_mutable(self) -> None:
        body = RigidBody(body_id=0)
        body.local_anchors["top"] = np.array([0.0, 0.0, 1.0])
        assert "top" in body.local_anchors

    def test_default_factory_isolation(self) -> None:
        """Each instance gets its own default containers."""
        first = RigidBody(body_id=0)
        second = RigidBody(body_id=1)
        first.local_anchors["x"] = np.zeros(3)
        # Mutating one instance's dict must not leak into another's.
        assert "x" not in second.local_anchors
|
||||||
|
class TestJoint:
    """Joint dataclass defaults and construction."""

    def test_defaults(self) -> None:
        joint = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE)
        # Anchors default to the body origins, axis to +Z, pitch to zero.
        np.testing.assert_array_equal(joint.anchor_a, np.zeros(3))
        np.testing.assert_array_equal(joint.anchor_b, np.zeros(3))
        np.testing.assert_array_equal(joint.axis, np.array([0.0, 0.0, 1.0]))
        assert joint.pitch == 0.0

    def test_full_construction(self) -> None:
        # Every field supplied explicitly; spot-check the scalars afterwards.
        joint = Joint(
            joint_id=5,
            body_a=2,
            body_b=3,
            joint_type=JointType.SCREW,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([2.0, 0.0, 0.0]),
            axis=np.array([1.0, 0.0, 0.0]),
            pitch=0.5,
        )
        assert joint.joint_id == 5
        assert joint.joint_type is JointType.SCREW
        assert joint.pitch == 0.5
|
||||||
|
class TestPebbleState:
    """PebbleState dataclass defaults."""

    def test_defaults(self) -> None:
        state = PebbleState()
        # All bookkeeping containers start empty.
        assert state.free_pebbles == {}
        assert state.directed_edges == {}
        assert state.independent_edges == set()
        assert state.redundant_edges == set()
        assert state.incoming == {}
        assert state.outgoing == {}

    def test_default_factory_isolation(self) -> None:
        # Containers must come from default factories, not shared defaults.
        first = PebbleState()
        second = PebbleState()
        first.free_pebbles[0] = 6
        assert 0 not in second.free_pebbles
|
||||||
|
class TestConstraintAnalysis:
    """ConstraintAnalysis dataclass construction."""

    def test_construction(self) -> None:
        # A fully rigid, minimally rigid assembly with no redundancy.
        analysis = ConstraintAnalysis(
            combinatorial_dof=6,
            combinatorial_internal_dof=0,
            combinatorial_redundant=0,
            combinatorial_classification="well-constrained",
            per_edge_results=[],
            jacobian_rank=6,
            jacobian_nullity=0,
            jacobian_internal_dof=0,
            numerically_dependent=[],
            geometric_degeneracies=0,
            is_rigid=True,
            is_minimally_rigid=True,
        )
        assert analysis.is_rigid is True
        assert analysis.is_minimally_rigid is True
        assert analysis.combinatorial_classification == "well-constrained"

    def test_per_edge_results_typing(self) -> None:
        """per_edge_results accepts list[dict[str, Any]]."""
        edge_records = [{"edge_id": 0, "independent": True}]
        analysis = ConstraintAnalysis(
            combinatorial_dof=7,
            combinatorial_internal_dof=1,
            combinatorial_redundant=0,
            combinatorial_classification="underconstrained",
            per_edge_results=edge_records,
            jacobian_rank=5,
            jacobian_nullity=1,
            jacobian_internal_dof=1,
            numerically_dependent=[],
            geometric_degeneracies=0,
            is_rigid=False,
            is_minimally_rigid=False,
        )
        assert len(analysis.per_edge_results) == 1
        assert analysis.per_edge_results[0]["edge_id"] == 0
Reference in New Issue
Block a user