Compare commits
25 Commits
e32c9cd793
...
public
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
f68245e952 | ||
|
|
b088b74dcf | ||
|
|
c728bd93f7 | ||
|
|
bbbc5e0137 | ||
|
|
40cda51142 | ||
|
|
e45207b7cc | ||
|
|
537d8c7689 | ||
| 93bda28f67 | |||
| 239e45c7f9 | |||
| 118474f892 | |||
| e8143cf64c | |||
| 9f53fdb154 | |||
| 5d1988b513 | |||
| f29060491e | |||
| 8a49f8ef40 | |||
| 78289494e2 | |||
| 0b5813b5a9 | |||
| dc742bfc82 | |||
| 831a10cdb4 | |||
| 9a31df4988 | |||
| 455b6318d9 | |||
| 35d4ef736f | |||
| 1b6135129e | |||
| 363b49281b | |||
| f61d005400 |
180
.gitea/workflows/ci.yaml
Normal file
180
.gitea/workflows/ci.yaml
Normal file
@@ -0,0 +1,180 @@
|
||||
name: CI
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [main, public]
|
||||
pull_request:
|
||||
branches: [main, public]
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
run_datagen:
|
||||
description: "Run dataset generation"
|
||||
required: false
|
||||
type: boolean
|
||||
default: false
|
||||
num_assemblies:
|
||||
description: "Number of assemblies to generate"
|
||||
required: false
|
||||
type: string
|
||||
default: "100000"
|
||||
num_workers:
|
||||
description: "Parallel workers for datagen"
|
||||
required: false
|
||||
type: string
|
||||
default: "4"
|
||||
|
||||
env:
|
||||
PIP_CACHE_DIR: /tmp/pip-cache-solver
|
||||
TORCH_INDEX: https://download.pytorch.org/whl/cpu
|
||||
VIRTUAL_ENV: /tmp/solver-venv
|
||||
|
||||
jobs:
|
||||
# ---------------------------------------------------------------------------
|
||||
# Lint — fast, no torch required
|
||||
# ---------------------------------------------------------------------------
|
||||
lint:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
steps:
|
||||
- name: Checkout
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
||||
cd "$GITHUB_WORKSPACE"
|
||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
||||
|
||||
- name: Set up venv
|
||||
run: python3 -m venv $VIRTUAL_ENV
|
||||
|
||||
- name: Install lint tools
|
||||
run: pip install --cache-dir $PIP_CACHE_DIR ruff
|
||||
|
||||
- name: Ruff check
|
||||
run: ruff check solver/ freecad/ tests/ scripts/
|
||||
|
||||
- name: Ruff format check
|
||||
run: ruff format --check solver/ freecad/ tests/ scripts/
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Type check
|
||||
# ---------------------------------------------------------------------------
|
||||
type-check:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
steps:
|
||||
- name: Checkout
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
||||
cd "$GITHUB_WORKSPACE"
|
||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
||||
|
||||
- name: Set up venv
|
||||
run: python3 -m venv $VIRTUAL_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
||||
pip install --cache-dir $PIP_CACHE_DIR mypy numpy scipy
|
||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[dev]"
|
||||
|
||||
- name: Mypy
|
||||
run: mypy solver/ freecad/
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
steps:
|
||||
- name: Checkout
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
||||
cd "$GITHUB_WORKSPACE"
|
||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
||||
|
||||
- name: Set up venv
|
||||
run: python3 -m venv $VIRTUAL_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[train,dev]"
|
||||
|
||||
- name: Run tests
|
||||
run: pytest tests/ freecad/tests/ -v --tb=short
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset generation — manual trigger or on main/public push
|
||||
# ---------------------------------------------------------------------------
|
||||
datagen:
|
||||
runs-on: ubuntu-latest
|
||||
if: >-
|
||||
(github.event_name == 'workflow_dispatch' && inputs.run_datagen == true) ||
|
||||
(github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/public'))
|
||||
needs: [test]
|
||||
env:
|
||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
||||
steps:
|
||||
- name: Checkout
|
||||
run: |
|
||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
||||
cd "$GITHUB_WORKSPACE"
|
||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
||||
|
||||
- name: Set up venv
|
||||
run: python3 -m venv $VIRTUAL_ENV
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[train]"
|
||||
|
||||
- name: Generate dataset
|
||||
run: |
|
||||
NUM=${INPUTS_NUM_ASSEMBLIES:-100000}
|
||||
WORKERS=${INPUTS_NUM_WORKERS:-4}
|
||||
echo "Generating ${NUM} assemblies with ${WORKERS} workers"
|
||||
python3 scripts/generate_synthetic.py \
|
||||
--num-assemblies "${NUM}" \
|
||||
--num-workers "${WORKERS}" \
|
||||
--output-dir data/synthetic
|
||||
env:
|
||||
INPUTS_NUM_ASSEMBLIES: ${{ inputs.num_assemblies }}
|
||||
INPUTS_NUM_WORKERS: ${{ inputs.num_workers }}
|
||||
|
||||
- name: Print summary
|
||||
if: always()
|
||||
run: |
|
||||
echo "=== Dataset Generation Results ==="
|
||||
if [ -f data/synthetic/stats.json ]; then
|
||||
python3 -c "
|
||||
import json
|
||||
with open('data/synthetic/stats.json') as f:
|
||||
s = json.load(f)
|
||||
print(f'Total examples: {s[\"total_examples\"]}')
|
||||
print(f'Classification: {json.dumps(s[\"classification_distribution\"], indent=2)}')
|
||||
print(f'Rigid: {s[\"rigidity\"][\"rigid_fraction\"]*100:.1f}%')
|
||||
print(f'Degeneracy: {s[\"geometric_degeneracy\"][\"fraction_with_degeneracy\"]*100:.1f}%')
|
||||
"
|
||||
else
|
||||
echo "stats.json not found — generation may have failed"
|
||||
ls -la data/synthetic/ 2>/dev/null || echo "output dir missing"
|
||||
fi
|
||||
77
.gitignore
vendored
77
.gitignore
vendored
@@ -1,44 +1,83 @@
|
||||
# Prerequisites
|
||||
# C++ compiled objects
|
||||
*.d
|
||||
|
||||
# Compiled Object files
|
||||
*.slo
|
||||
*.lo
|
||||
*.o
|
||||
*.obj
|
||||
|
||||
# Precompiled Headers
|
||||
*.gch
|
||||
*.pch
|
||||
|
||||
# Compiled Dynamic libraries
|
||||
# C++ libraries
|
||||
*.so
|
||||
*.dylib
|
||||
*.dll
|
||||
|
||||
# Fortran module files
|
||||
*.mod
|
||||
*.smod
|
||||
|
||||
# Compiled Static libraries
|
||||
*.lai
|
||||
*.la
|
||||
*.a
|
||||
*.lib
|
||||
|
||||
# Executables
|
||||
# C++ executables
|
||||
*.exe
|
||||
*.out
|
||||
*.app
|
||||
|
||||
.vs
|
||||
# C++ build
|
||||
build/
|
||||
cmake-build-debug/
|
||||
.vs/
|
||||
x64/
|
||||
temp/
|
||||
|
||||
# OndselSolver test artifacts
|
||||
*.bak
|
||||
assembly.asmt
|
||||
|
||||
build
|
||||
cmake-build-debug
|
||||
.idea
|
||||
temp/
|
||||
/testapp/draggingBackhoe.log
|
||||
/testapp/runPreDragBackhoe.asmt
|
||||
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.egg-info/
|
||||
dist/
|
||||
*.egg
|
||||
|
||||
# Virtual environments
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
*.swp
|
||||
*.swo
|
||||
*~
|
||||
|
||||
# mypy / ruff / pytest
|
||||
.mypy_cache/
|
||||
.ruff_cache/
|
||||
.pytest_cache/
|
||||
|
||||
# Data (large files tracked separately)
|
||||
data/synthetic/*.pt
|
||||
data/fusion360/*.json
|
||||
data/fusion360/*.step
|
||||
data/processed/*.pt
|
||||
!data/**/.gitkeep
|
||||
|
||||
# Model checkpoints
|
||||
*.ckpt
|
||||
*.pth
|
||||
*.onnx
|
||||
*.torchscript
|
||||
|
||||
# Experiment tracking
|
||||
wandb/
|
||||
runs/
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Environment
|
||||
.env
|
||||
|
||||
23
.pre-commit-config.yaml
Normal file
23
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,23 @@
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.3.4
|
||||
hooks:
|
||||
- id: ruff
|
||||
args: [--fix]
|
||||
- id: ruff-format
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v1.8.0
|
||||
hooks:
|
||||
- id: mypy
|
||||
additional_dependencies:
|
||||
- torch>=2.2
|
||||
- numpy>=1.26
|
||||
args: [--ignore-missing-imports]
|
||||
|
||||
- repo: https://github.com/compilerla/conventional-pre-commit
|
||||
rev: v3.1.0
|
||||
hooks:
|
||||
- id: conventional-pre-commit
|
||||
stages: [commit-msg]
|
||||
args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert]
|
||||
61
Dockerfile
Normal file
61
Dockerfile
Normal file
@@ -0,0 +1,61 @@
|
||||
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
# System deps
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
python3.11 python3.11-venv python3.11-dev python3-pip \
|
||||
git wget curl \
|
||||
# FreeCAD headless deps
|
||||
freecad \
|
||||
libgl1-mesa-glx libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
|
||||
|
||||
# Create venv
|
||||
RUN python -m venv /opt/venv
|
||||
ENV PATH="/opt/venv/bin:$PATH"
|
||||
|
||||
# Install PyTorch with CUDA
|
||||
RUN pip install --no-cache-dir \
|
||||
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
|
||||
|
||||
# Install PyG
|
||||
RUN pip install --no-cache-dir \
|
||||
torch-geometric \
|
||||
pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
|
||||
-f https://data.pyg.org/whl/torch-2.4.0+cu124.html
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
# Install project
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]" || true
|
||||
|
||||
COPY . .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||
|
||||
# -------------------------------------------------------------------
|
||||
FROM base AS cpu
|
||||
|
||||
# CPU-only variant (for CI and non-GPU environments)
|
||||
FROM python:3.11-slim AS cpu-only
|
||||
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
|
||||
RUN apt-get update && apt-get install -y --no-install-recommends \
|
||||
git freecad libgl1-mesa-glx libglib2.0-0 \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
COPY pyproject.toml .
|
||||
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
|
||||
RUN pip install --no-cache-dir torch-geometric
|
||||
|
||||
COPY . .
|
||||
RUN pip install --no-cache-dir -e ".[train,dev]"
|
||||
|
||||
CMD ["pytest", "tests/", "-v"]
|
||||
48
Makefile
Normal file
48
Makefile
Normal file
@@ -0,0 +1,48 @@
|
||||
.PHONY: train test lint data-gen export format type-check install dev clean help
|
||||
|
||||
PYTHON ?= python
|
||||
PYTEST ?= pytest
|
||||
RUFF ?= ruff
|
||||
MYPY ?= mypy
|
||||
|
||||
help: ## Show this help
|
||||
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
|
||||
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
|
||||
|
||||
install: ## Install core dependencies
|
||||
pip install -e .
|
||||
|
||||
dev: ## Install all dependencies including dev tools
|
||||
pip install -e ".[train,dev]"
|
||||
pre-commit install
|
||||
pre-commit install --hook-type commit-msg
|
||||
|
||||
train: ## Run training (pass CONFIG=path/to/config.yaml)
|
||||
$(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG))
|
||||
|
||||
test: ## Run test suite
|
||||
$(PYTEST) tests/ freecad/tests/ -v --tb=short
|
||||
|
||||
lint: ## Run ruff linter
|
||||
$(RUFF) check solver/ freecad/ tests/ scripts/
|
||||
|
||||
format: ## Format code with ruff
|
||||
$(RUFF) format solver/ freecad/ tests/ scripts/
|
||||
$(RUFF) check --fix solver/ freecad/ tests/ scripts/
|
||||
|
||||
type-check: ## Run mypy type checker
|
||||
$(MYPY) solver/ freecad/
|
||||
|
||||
data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml)
|
||||
$(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG))
|
||||
|
||||
export: ## Export trained model for deployment
|
||||
$(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL))
|
||||
|
||||
clean: ## Remove build artifacts and caches
|
||||
rm -rf build/ dist/ *.egg-info/
|
||||
rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/
|
||||
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
|
||||
find . -type f -name "*.pyc" -delete 2>/dev/null || true
|
||||
|
||||
check: lint type-check test ## Run all checks (lint, type-check, test)
|
||||
94
README.md
94
README.md
@@ -1,7 +1,91 @@
|
||||
# MbDCode
|
||||
Assembly Constraints and Multibody Dynamics code
|
||||
# Kindred Solver
|
||||
|
||||
Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited)
|
||||
Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer.
|
||||
|
||||
The MbD theory is at
|
||||
https://github.com/Ondsel-Development/MbDTheory
|
||||
## Components
|
||||
|
||||
### OndselSolver (C++)
|
||||
|
||||
Numerical assembly constraint solver using multibody dynamics. Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver.
|
||||
|
||||
- Source: `OndselSolver/`
|
||||
- Entry point: `OndselSolverMain/`
|
||||
- Tests: `tests/`, `testapp/`
|
||||
- Build: CMake
|
||||
|
||||
**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory)
|
||||
|
||||
#### Building
|
||||
|
||||
```bash
|
||||
cmake -B build -DCMAKE_BUILD_TYPE=Release
|
||||
cmake --build build
|
||||
```
|
||||
|
||||
### ML Solver Layer (Python)
|
||||
|
||||
Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns.
|
||||
|
||||
- Core library: `solver/`
|
||||
- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling)
|
||||
- Model architectures: `solver/models/` (GIN, GAT, NNConv)
|
||||
- Training: `solver/training/`
|
||||
- Inference: `solver/inference/`
|
||||
- FreeCAD integration: `freecad/`
|
||||
- Configuration: `configs/` (Hydra)
|
||||
|
||||
#### Setup
|
||||
|
||||
```bash
|
||||
pip install -e ".[train,dev]"
|
||||
pre-commit install
|
||||
```
|
||||
|
||||
#### Usage
|
||||
|
||||
```bash
|
||||
make help # show all targets
|
||||
make dev # install all deps + pre-commit hooks
|
||||
make test # run tests
|
||||
make lint # run ruff linter
|
||||
make check # lint + type-check + test
|
||||
make data-gen # generate synthetic data
|
||||
make train # run training
|
||||
make export # export model
|
||||
```
|
||||
|
||||
Docker is also supported:
|
||||
|
||||
```bash
|
||||
docker compose up train # GPU training
|
||||
docker compose up test # run tests
|
||||
docker compose up data-gen # generate synthetic data
|
||||
```
|
||||
|
||||
## Repository structure
|
||||
|
||||
```
|
||||
kindred-solver/
|
||||
├── OndselSolver/ # C++ numerical solver library
|
||||
├── OndselSolverMain/ # C++ solver CLI entry point
|
||||
├── tests/ # C++ unit tests + Python tests
|
||||
├── testapp/ # C++ test application
|
||||
├── solver/ # Python ML solver library
|
||||
│ ├── datagen/ # Synthetic data generation (pebble game)
|
||||
│ ├── datasets/ # PyG dataset adapters
|
||||
│ ├── models/ # GNN architectures
|
||||
│ ├── training/ # Training loops
|
||||
│ ├── evaluation/ # Metrics and visualization
|
||||
│ └── inference/ # Runtime prediction API
|
||||
├── freecad/ # FreeCAD workbench integration
|
||||
├── configs/ # Hydra configs (dataset, model, training, export)
|
||||
├── scripts/ # CLI utilities
|
||||
├── data/ # Datasets (not committed)
|
||||
├── export/ # Model packaging
|
||||
└── docs/ # Documentation
|
||||
```
|
||||
|
||||
## License
|
||||
|
||||
OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE))
|
||||
ML Solver Layer: Apache-2.0
|
||||
|
||||
12
configs/dataset/fusion360.yaml
Normal file
12
configs/dataset/fusion360.yaml
Normal file
@@ -0,0 +1,12 @@
|
||||
# Fusion 360 Gallery dataset config
|
||||
name: fusion360
|
||||
data_dir: data/fusion360
|
||||
output_dir: data/processed
|
||||
|
||||
splits:
|
||||
train: 0.8
|
||||
val: 0.1
|
||||
test: 0.1
|
||||
|
||||
stratify_by: complexity
|
||||
seed: 42
|
||||
26
configs/dataset/synthetic.yaml
Normal file
26
configs/dataset/synthetic.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
# Synthetic dataset generation config
|
||||
name: synthetic
|
||||
num_assemblies: 100000
|
||||
output_dir: data/synthetic
|
||||
shard_size: 1000
|
||||
|
||||
complexity_distribution:
|
||||
simple: 0.4 # 2-5 bodies
|
||||
medium: 0.4 # 6-15 bodies
|
||||
complex: 0.2 # 16-50 bodies
|
||||
|
||||
body_count:
|
||||
min: 2
|
||||
max: 50
|
||||
|
||||
templates:
|
||||
- chain
|
||||
- tree
|
||||
- loop
|
||||
- star
|
||||
- mixed
|
||||
|
||||
grounded_ratio: 0.5
|
||||
seed: 42
|
||||
num_workers: 4
|
||||
checkpoint_every: 5
|
||||
25
configs/export/production.yaml
Normal file
25
configs/export/production.yaml
Normal file
@@ -0,0 +1,25 @@
|
||||
# Production model export config
|
||||
model_checkpoint: checkpoints/finetune/best_val_loss.ckpt
|
||||
output_dir: export/
|
||||
|
||||
formats:
|
||||
onnx:
|
||||
enabled: true
|
||||
opset_version: 17
|
||||
dynamic_axes: true
|
||||
torchscript:
|
||||
enabled: true
|
||||
|
||||
model_card:
|
||||
version: "0.1.0"
|
||||
architecture: baseline
|
||||
training_data:
|
||||
- synthetic_100k
|
||||
- fusion360_gallery
|
||||
|
||||
size_budget_mb: 50
|
||||
|
||||
inference:
|
||||
device: cpu
|
||||
batch_size: 1
|
||||
confidence_threshold: 0.8
|
||||
24
configs/model/baseline.yaml
Normal file
24
configs/model/baseline.yaml
Normal file
@@ -0,0 +1,24 @@
|
||||
# Baseline GIN model config
|
||||
name: baseline
|
||||
architecture: gin
|
||||
|
||||
encoder:
|
||||
num_layers: 3
|
||||
hidden_dim: 128
|
||||
dropout: 0.1
|
||||
|
||||
node_features_dim: 22
|
||||
edge_features_dim: 22
|
||||
|
||||
heads:
|
||||
edge_classification:
|
||||
enabled: true
|
||||
hidden_dim: 64
|
||||
graph_classification:
|
||||
enabled: true
|
||||
num_classes: 4 # rigid, under, over, mixed
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
dof_regression:
|
||||
enabled: true
|
||||
28
configs/model/gat.yaml
Normal file
28
configs/model/gat.yaml
Normal file
@@ -0,0 +1,28 @@
|
||||
# Advanced GAT model config
|
||||
name: gat_solver
|
||||
architecture: gat
|
||||
|
||||
encoder:
|
||||
num_layers: 4
|
||||
hidden_dim: 256
|
||||
num_heads: 8
|
||||
dropout: 0.1
|
||||
residual: true
|
||||
|
||||
node_features_dim: 22
|
||||
edge_features_dim: 22
|
||||
|
||||
heads:
|
||||
edge_classification:
|
||||
enabled: true
|
||||
hidden_dim: 128
|
||||
graph_classification:
|
||||
enabled: true
|
||||
num_classes: 4
|
||||
joint_type:
|
||||
enabled: true
|
||||
num_classes: 12
|
||||
dof_regression:
|
||||
enabled: true
|
||||
dof_tracking:
|
||||
enabled: true
|
||||
45
configs/training/finetune.yaml
Normal file
45
configs/training/finetune.yaml
Normal file
@@ -0,0 +1,45 @@
|
||||
# Fine-tuning on real data config
|
||||
phase: finetune
|
||||
|
||||
dataset: fusion360
|
||||
model: baseline
|
||||
|
||||
pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt
|
||||
|
||||
optimizer:
|
||||
name: adamw
|
||||
lr: 1e-5
|
||||
weight_decay: 1e-4
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
T_max: 50
|
||||
eta_min: 1e-7
|
||||
|
||||
training:
|
||||
epochs: 50
|
||||
batch_size: 32
|
||||
gradient_clip: 1.0
|
||||
early_stopping_patience: 10
|
||||
amp: true
|
||||
freeze_encoder: false # set true for frozen encoder experiment
|
||||
|
||||
loss:
|
||||
edge_weight: 1.0
|
||||
graph_weight: 0.5
|
||||
joint_type_weight: 0.3
|
||||
dof_weight: 0.2
|
||||
redundant_penalty: 2.0
|
||||
|
||||
checkpointing:
|
||||
save_best_val_loss: true
|
||||
save_best_val_accuracy: true
|
||||
save_every_n_epochs: 5
|
||||
checkpoint_dir: checkpoints/finetune
|
||||
|
||||
logging:
|
||||
backend: wandb
|
||||
project: kindred-solver
|
||||
log_every_n_steps: 20
|
||||
|
||||
seed: 42
|
||||
42
configs/training/pretrain.yaml
Normal file
42
configs/training/pretrain.yaml
Normal file
@@ -0,0 +1,42 @@
|
||||
# Synthetic pre-training config
|
||||
phase: pretrain
|
||||
|
||||
dataset: synthetic
|
||||
model: baseline
|
||||
|
||||
optimizer:
|
||||
name: adamw
|
||||
lr: 1e-3
|
||||
weight_decay: 1e-4
|
||||
|
||||
scheduler:
|
||||
name: cosine_annealing
|
||||
T_max: 100
|
||||
eta_min: 1e-6
|
||||
|
||||
training:
|
||||
epochs: 100
|
||||
batch_size: 64
|
||||
gradient_clip: 1.0
|
||||
early_stopping_patience: 10
|
||||
amp: true
|
||||
|
||||
loss:
|
||||
edge_weight: 1.0
|
||||
graph_weight: 0.5
|
||||
joint_type_weight: 0.3
|
||||
dof_weight: 0.2
|
||||
redundant_penalty: 2.0 # safety loss multiplier
|
||||
|
||||
checkpointing:
|
||||
save_best_val_loss: true
|
||||
save_best_val_accuracy: true
|
||||
save_every_n_epochs: 10
|
||||
checkpoint_dir: checkpoints/pretrain
|
||||
|
||||
logging:
|
||||
backend: wandb # or tensorboard
|
||||
project: kindred-solver
|
||||
log_every_n_steps: 50
|
||||
|
||||
seed: 42
|
||||
0
data/fusion360/.gitkeep
Normal file
0
data/fusion360/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/processed/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/splits/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
0
data/synthetic/.gitkeep
Normal file
39
docker-compose.yml
Normal file
39
docker-compose.yml
Normal file
@@ -0,0 +1,39 @@
|
||||
services:
|
||||
train:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: base
|
||||
volumes:
|
||||
- .:/workspace
|
||||
- ./data:/workspace/data
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
count: all
|
||||
capabilities: [gpu]
|
||||
command: make train
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
|
||||
- WANDB_API_KEY=${WANDB_API_KEY:-}
|
||||
|
||||
test:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: cpu-only
|
||||
volumes:
|
||||
- .:/workspace
|
||||
command: make check
|
||||
|
||||
data-gen:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
target: base
|
||||
volumes:
|
||||
- .:/workspace
|
||||
- ./data:/workspace/data
|
||||
command: make data-gen
|
||||
0
docs/.gitkeep
Normal file
0
docs/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
export/.gitkeep
Normal file
0
freecad/__init__.py
Normal file
0
freecad/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/bridge/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/tests/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
0
freecad/workbench/__init__.py
Normal file
97
pyproject.toml
Normal file
97
pyproject.toml
Normal file
@@ -0,0 +1,97 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[project]
|
||||
name = "kindred-solver"
|
||||
version = "0.1.0"
|
||||
description = "Assembly constraint prediction via GNN for Kindred Create"
|
||||
readme = "README.md"
|
||||
license = "Apache-2.0"
|
||||
requires-python = ">=3.11"
|
||||
authors = [
|
||||
{ name = "Kindred Systems" },
|
||||
]
|
||||
classifiers = [
|
||||
"Development Status :: 3 - Alpha",
|
||||
"License :: OSI Approved :: Apache Software License",
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Topic :: Scientific/Engineering",
|
||||
]
|
||||
dependencies = [
|
||||
"torch>=2.2",
|
||||
"torch-geometric>=2.5",
|
||||
"numpy>=1.26",
|
||||
"scipy>=1.12",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
train = [
|
||||
"wandb>=0.16",
|
||||
"tensorboard>=2.16",
|
||||
"hydra-core>=1.3",
|
||||
"omegaconf>=2.3",
|
||||
"matplotlib>=3.8",
|
||||
"networkx>=3.2",
|
||||
]
|
||||
freecad = [
|
||||
"pyside6>=6.6",
|
||||
]
|
||||
dev = [
|
||||
"pytest>=8.0",
|
||||
"pytest-cov>=4.1",
|
||||
"ruff>=0.3",
|
||||
"mypy>=1.8",
|
||||
"pre-commit>=3.6",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Repository = "https://git.kindred-systems.com/kindred/solver"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["solver", "freecad"]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py311"
|
||||
line-length = 100
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle errors
|
||||
"W", # pycodestyle warnings
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
"N", # pep8-naming
|
||||
"UP", # pyupgrade
|
||||
"B", # flake8-bugbear
|
||||
"SIM", # flake8-simplify
|
||||
"TCH", # flake8-type-checking
|
||||
"RUF", # ruff-specific
|
||||
]
|
||||
|
||||
[tool.ruff.lint.isort]
|
||||
known-first-party = ["solver", "freecad"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.11"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = true
|
||||
check_untyped_defs = true
|
||||
|
||||
[[tool.mypy.overrides]]
|
||||
module = [
|
||||
"torch.*",
|
||||
"torch_geometric.*",
|
||||
"scipy.*",
|
||||
"wandb.*",
|
||||
"hydra.*",
|
||||
"omegaconf.*",
|
||||
]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
testpaths = ["tests", "freecad/tests"]
|
||||
addopts = "-v --tb=short"
|
||||
115
scripts/generate_synthetic.py
Normal file
115
scripts/generate_synthetic.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate synthetic assembly dataset for kindred-solver training.
|
||||
|
||||
Usage (argparse fallback — always available)::
|
||||
|
||||
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
|
||||
|
||||
Usage (Hydra — when hydra-core is installed)::
|
||||
|
||||
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _try_hydra_main() -> bool:
|
||||
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
||||
try:
|
||||
import hydra # type: ignore[import-untyped]
|
||||
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@hydra.main(
|
||||
config_path="../configs/dataset",
|
||||
config_name="synthetic",
|
||||
version_base=None,
|
||||
)
|
||||
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
|
||||
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
_run() # type: ignore[no-untyped-call]
|
||||
return True
|
||||
|
||||
|
||||
def _argparse_main() -> None:
|
||||
"""Fallback CLI using argparse."""
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate synthetic assembly dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--config",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Path to YAML config file (optional)",
|
||||
)
|
||||
parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
|
||||
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
|
||||
parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
|
||||
parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
|
||||
parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
|
||||
parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
|
||||
parser.add_argument("--seed", type=int, default=None, help="Random seed")
|
||||
parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
|
||||
parser.add_argument(
|
||||
"--checkpoint-every",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Checkpoint interval (shards)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--no-resume",
|
||||
action="store_true",
|
||||
help="Do not resume from existing checkpoints",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
config_dict: dict[str, object] = {}
|
||||
if args.config:
|
||||
config_dict = parse_simple_yaml(args.config) # type: ignore[assignment]
|
||||
|
||||
# CLI args override config file (only when explicitly provided)
|
||||
_override_map = {
|
||||
"num_assemblies": args.num_assemblies,
|
||||
"output_dir": args.output_dir,
|
||||
"shard_size": args.shard_size,
|
||||
"body_count_min": args.body_count_min,
|
||||
"body_count_max": args.body_count_max,
|
||||
"grounded_ratio": args.grounded_ratio,
|
||||
"seed": args.seed,
|
||||
"num_workers": args.num_workers,
|
||||
"checkpoint_every": args.checkpoint_every,
|
||||
}
|
||||
for key, val in _override_map.items():
|
||||
if val is not None:
|
||||
config_dict[key] = val
|
||||
|
||||
if args.no_resume:
|
||||
config_dict["resume"] = False
|
||||
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Entry point: try Hydra first, fall back to argparse."""
|
||||
if not _try_hydra_main():
|
||||
_argparse_main()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
0
solver/__init__.py
Normal file
0
solver/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
37
solver/datagen/__init__.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
from solver.datagen.generator import (
|
||||
COMPLEXITY_RANGES,
|
||||
AxisStrategy,
|
||||
SyntheticAssemblyGenerator,
|
||||
)
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.labeling import AssemblyLabels, label_assembly
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
PebbleState,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AssemblyLabels",
|
||||
"AxisStrategy",
|
||||
"ConstraintAnalysis",
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"JacobianVerifier",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleGame3D",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
"SyntheticAssemblyGenerator",
|
||||
"analyze_assembly",
|
||||
"label_assembly",
|
||||
]
|
||||
140
solver/datagen/analysis.py
Normal file
140
solver/datagen/analysis.py
Normal file
@@ -0,0 +1,140 @@
|
||||
"""Combined pebble game + Jacobian verification analysis.
|
||||
|
||||
Provides :func:`analyze_assembly`, the main entry point for full rigidity
|
||||
analysis of an assembly using both combinatorial and numerical methods.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
__all__ = ["analyze_assembly"]
|
||||
|
||||
_GROUND_ID = -1
|
||||
|
||||
|
||||
def analyze_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> ConstraintAnalysis:
    """Full rigidity analysis of an assembly using both methods.

    Runs the combinatorial 3D pebble game and a numerical Jacobian rank
    test over the same assembly, then combines both views into a single
    :class:`ConstraintAnalysis`.

    Args:
        bodies: List of rigid bodies in the assembly.
        joints: List of joints connecting bodies.
        ground_body: If set, this body is fixed (adds 6 implicit constraints).

    Returns:
        ConstraintAnalysis with combinatorial and numerical results.
    """
    # --- Pebble Game ---
    pg = PebbleGame3D()
    all_edge_results = []  # per-edge independence results, in joint order

    # Add a virtual ground body (id=-1) if grounding is requested.
    # Grounding body X means adding a fixed joint between X and
    # the virtual ground. This properly lets the pebble game account
    # for the 6 removed DOF without breaking invariants.
    if ground_body is not None:
        pg.add_body(_GROUND_ID)

    for body in bodies:
        pg.add_body(body.body_id)

    if ground_body is not None:
        # Anchor placement is irrelevant for a FIXED joint; the first body's
        # position (or the origin for an empty assembly) is used arbitrarily.
        ground_joint = Joint(
            joint_id=-1,
            body_a=ground_body,
            body_b=_GROUND_ID,
            joint_type=JointType.FIXED,
            anchor_a=bodies[0].position if bodies else np.zeros(3),
            anchor_b=bodies[0].position if bodies else np.zeros(3),
        )
        pg.add_joint(ground_joint)
        # Don't include ground joint edges in the output labels
        # (they're infrastructure, not user constraints)

    for joint in joints:
        results = pg.add_joint(joint)
        all_edge_results.extend(results)

    combinatorial_independent = len(pg.state.independent_edges)
    grounded = ground_body is not None

    # The virtual ground body contributes 6 pebbles to the total.
    # Subtract those from the reported DOF for user-facing numbers.
    raw_dof = pg.get_dof()
    ground_offset = 6 if grounded else 0
    effective_dof = raw_dof - ground_offset
    # Internal DOF excludes the 6 trivial rigid-body motions; when grounded
    # those are already removed via the virtual ground joint above.
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))

    # Classify based on effective (adjusted) DOF, not raw pebble game output,
    # because the virtual ground body skews the raw numbers.
    redundant = pg.get_redundant_count()
    if redundant > 0 and effective_internal_dof > 0:
        combinatorial_classification = "mixed"
    elif redundant > 0:
        combinatorial_classification = "overconstrained"
    elif effective_internal_dof > 0:
        combinatorial_classification = "underconstrained"
    else:
        combinatorial_classification = "well-constrained"

    # --- Jacobian Verification ---
    verifier = JacobianVerifier(bodies)

    for joint in joints:
        verifier.add_joint_constraints(joint)

    # If grounded, remove the ground body's columns (fix its DOF)
    j = verifier.get_jacobian()
    if ground_body is not None and j.size > 0:
        idx = verifier.body_index[ground_body]
        cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
        j = np.delete(j, cols_to_remove, axis=1)

    # Numerical rank via singular values; 1e-8 is the rank tolerance.
    if j.size > 0:
        sv = np.linalg.svd(j, compute_uv=False)
        jacobian_rank = int(np.sum(sv > 1e-8))
    else:
        jacobian_rank = 0

    # NOTE(review): when j is empty but grounded, the fallback column count
    # does not subtract the ground body's 6 columns — confirm intended.
    n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies)
    jacobian_nullity = n_cols - jacobian_rank

    dependent = verifier.find_dependencies()

    # Adjust for ground
    trivial_dof = 0 if ground_body is not None else 6
    jacobian_internal_dof = jacobian_nullity - trivial_dof

    # Constraints the pebble game counts as independent but the Jacobian
    # finds rank-deficient — geometry-induced degeneracies.
    geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)

    # Rigidity: numerically rigid if nullity == trivial DOF
    is_rigid = jacobian_nullity <= trivial_dof
    is_minimally_rigid = is_rigid and len(dependent) == 0

    return ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=pg.get_redundant_count(),
        combinatorial_classification=combinatorial_classification,
        per_edge_results=all_edge_results,
        jacobian_rank=jacobian_rank,
        jacobian_nullity=jacobian_nullity,
        jacobian_internal_dof=max(0, jacobian_internal_dof),
        numerically_dependent=dependent,
        geometric_degeneracies=geometric_degeneracies,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
    )
|
||||
624
solver/datagen/dataset.py
Normal file
624
solver/datagen/dataset.py
Normal file
@@ -0,0 +1,624 @@
|
||||
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
|
||||
|
||||
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
|
||||
for parallel generation of synthetic assembly training data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"parse_simple_yaml",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class DatasetConfig:
    """Configuration for synthetic dataset generation."""

    name: str = "synthetic"
    num_assemblies: int = 100_000
    output_dir: str = "data/synthetic"
    shard_size: int = 1000
    complexity_distribution: dict[str, float] = field(
        default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
    )
    body_count_min: int = 2
    body_count_max: int = 50
    templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"])
    grounded_ratio: float = 0.5
    seed: int = 42
    num_workers: int = 4
    checkpoint_every: int = 5
    resume: bool = True

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
        """Construct from a parsed config dict (e.g. YAML or OmegaConf).

        Handles both flat keys (``body_count_min``) and nested forms
        (``body_count: {min: 2, max: 50}``).
        """
        # Scalar keys copied through verbatim when present.
        passthrough = (
            "name",
            "num_assemblies",
            "output_dir",
            "shard_size",
            "grounded_ratio",
            "seed",
            "num_workers",
            "checkpoint_every",
            "resume",
        )
        kw: dict[str, Any] = {key: d[key] for key in passthrough if key in d}

        # body_count may be a nested {min:, max:} mapping or flat keys.
        nested = d.get("body_count")
        if isinstance(nested, dict):
            if "min" in nested:
                kw["body_count_min"] = int(nested["min"])
            if "max" in nested:
                kw["body_count_max"] = int(nested["max"])
        else:
            for flat_key in ("body_count_min", "body_count_max"):
                if flat_key in d:
                    kw[flat_key] = int(d[flat_key])

        # Coerce tier weights and template names to canonical types.
        cd = d.get("complexity_distribution")
        if isinstance(cd, dict):
            kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()}
        tmpl = d.get("templates")
        if isinstance(tmpl, list):
            kw["templates"] = [str(t) for t in tmpl]

        return cls(**kw)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard specification / result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class ShardSpec:
    """Specification for generating a single shard."""

    shard_id: int  # zero-based shard index
    start_example_id: int  # global id of this shard's first example
    count: int  # number of examples to generate in this shard
    seed: int  # per-shard RNG seed (derived from the global seed)
    complexity_distribution: dict[str, float]  # tier name -> sampling weight
    body_count_min: int  # inclusive lower bound on bodies per assembly
    body_count_max: int  # upper bound on bodies per assembly
    grounded_ratio: float  # fraction of assemblies generated with a grounded body
|
||||
|
||||
|
||||
@dataclass
class ShardResult:
    """Result returned from a shard worker."""

    shard_id: int  # matches the ShardSpec that produced it
    num_examples: int  # examples actually written (failures are skipped)
    file_path: str  # on-disk location of the saved shard
    generation_time_s: float  # wall-clock generation time
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
|
||||
"""Derive a deterministic per-shard seed from the global seed."""
|
||||
h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest()
|
||||
return int(h[:8], 16)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PrintProgress:
|
||||
"""Fallback progress display when tqdm is unavailable."""
|
||||
|
||||
def __init__(self, total: int) -> None:
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
def update(self, n: int = 1) -> None:
|
||||
self.current += n
|
||||
elapsed = time.monotonic() - self.start_time
|
||||
rate = self.current / elapsed if elapsed > 0 else 0.0
|
||||
eta = (self.total - self.current) / rate if rate > 0 else 0.0
|
||||
pct = 100.0 * self.current / self.total
|
||||
sys.stdout.write(
|
||||
f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
|
||||
f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _make_progress(total: int) -> _PrintProgress:
    """Create a progress tracker (tqdm if available, else print-based)."""
    try:
        from tqdm import tqdm  # type: ignore[import-untyped]
    except ImportError:
        # tqdm not installed — fall back to the plain stdout tracker.
        return _PrintProgress(total)
    return tqdm(total=total, desc="Generating shards", unit="shard")  # type: ignore[no-any-return,return-value]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _save_shard(
|
||||
shard_id: int,
|
||||
examples: list[dict[str, Any]],
|
||||
shards_dir: Path,
|
||||
) -> Path:
|
||||
"""Save a shard to disk (.pt if torch available, else .json)."""
|
||||
shards_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
path = shards_dir / f"shard_{shard_id:05d}.pt"
|
||||
torch.save(examples, path)
|
||||
except ImportError:
|
||||
path = shards_dir / f"shard_{shard_id:05d}.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(examples, f)
|
||||
return path
|
||||
|
||||
|
||||
def _load_shard(path: Path) -> list[dict[str, Any]]:
|
||||
"""Load a shard from disk (.pt or .json)."""
|
||||
if path.suffix == ".pt":
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
result: list[dict[str, Any]] = torch.load(path, weights_only=False)
|
||||
return result
|
||||
with open(path) as f:
|
||||
result = json.load(f)
|
||||
return result
|
||||
|
||||
|
||||
def _shard_format() -> str:
|
||||
"""Return the shard file extension based on available libraries."""
|
||||
try:
|
||||
import torch # type: ignore[import-untyped] # noqa: F401
|
||||
|
||||
return ".pt"
|
||||
except ImportError:
|
||||
return ".json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard worker (module-level for pickling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
    """Generate a single shard — top-level function for ProcessPoolExecutor."""
    # Imported lazily so worker processes pick it up after fork/spawn.
    from solver.datagen.generator import SyntheticAssemblyGenerator

    started = time.monotonic()
    gen = SyntheticAssemblyGenerator(seed=spec.seed)
    # Separate stream (seed+1) for tier sampling so it doesn't perturb
    # the generator's own RNG.
    rng = np.random.default_rng(spec.seed + 1)

    # Normalize tier weights into a probability vector.
    tiers = list(spec.complexity_distribution.keys())
    weights = [spec.complexity_distribution[t] for t in tiers]
    weight_sum = sum(weights)
    probs = [w / weight_sum for w in weights]

    examples: list[dict[str, Any]] = []
    for i in range(spec.count):
        tier = tiers[int(rng.choice(len(tiers), p=probs))]
        try:
            batch = gen.generate_training_batch(
                batch_size=1,
                complexity_tier=tier,  # type: ignore[arg-type]
                grounded_ratio=spec.grounded_ratio,
            )
            example = batch[0]
            example["example_id"] = spec.start_example_id + i
            example["complexity_tier"] = tier
            examples.append(example)
        except Exception:
            # Best-effort: log and skip the failed example; the shard
            # simply ends up with fewer entries.
            logger.warning(
                "Shard %d, example %d failed — skipping",
                spec.shard_id,
                i,
                exc_info=True,
            )

    shard_path = _save_shard(spec.shard_id, examples, Path(output_dir) / "shards")

    return ShardResult(
        shard_id=spec.shard_id,
        num_examples=len(examples),
        file_path=str(shard_path),
        generation_time_s=time.monotonic() - started,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_scalar(value: str) -> int | float | bool | str:
|
||||
"""Parse a YAML scalar value."""
|
||||
# Strip inline comments (space + #)
|
||||
if " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
elif " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
v = value.strip()
|
||||
if v.lower() in ("true", "yes"):
|
||||
return True
|
||||
if v.lower() in ("false", "no"):
|
||||
return False
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(v)
|
||||
except ValueError:
|
||||
pass
|
||||
return v.strip("'\"")
|
||||
|
||||
|
||||
def parse_simple_yaml(path: str) -> dict[str, Any]:
    """Parse a simple YAML file (flat scalars, one-level dicts, lists).

    This is **not** a full YAML parser. It handles the structure of
    ``configs/dataset/synthetic.yaml``: top-level ``key: value`` scalars,
    one level of nested mappings, and flat ``- item`` lists.
    """
    result: dict[str, Any] = {}
    # Most recent top-level key that had no inline value; indented lines
    # (dict entries or list items) attach under it.
    current_key: str | None = None

    with open(path) as f:
        for raw_line in f:
            line = raw_line.rstrip()

            # Skip blank lines and full-line comments
            if not line or line.lstrip().startswith("#"):
                continue

            indent = len(line) - len(line.lstrip())

            # Top-level "key: value" or bare "key:" (container opener).
            if indent == 0 and ":" in line:
                key, _, value = line.partition(":")
                key = key.strip()
                value = value.strip()
                if value:
                    result[key] = _parse_scalar(value)
                    current_key = None
                else:
                    # No inline value: default the container to an empty
                    # dict; it is converted to a list if "- " items follow.
                    current_key = key
                    result[key] = {}
                continue

            # Indented "- item" list entry under the current key.
            if indent > 0 and line.lstrip().startswith("- "):
                item = line.lstrip()[2:].strip()
                if current_key is not None:
                    # First list item: replace the placeholder empty dict.
                    if isinstance(result.get(current_key), dict) and not result[current_key]:
                        result[current_key] = []
                    if isinstance(result.get(current_key), list):
                        result[current_key].append(_parse_scalar(item))
                continue

            # Indented "key: value" entry of the current nested mapping.
            if indent > 0 and ":" in line and current_key is not None:
                k, _, v = line.partition(":")
                k = k.strip()
                v = v.strip()
                if v:
                    # Strip inline comments
                    if " #" in v:
                        v = v[: v.index(" #")].strip()
                    if not isinstance(result.get(current_key), dict):
                        result[current_key] = {}
                    result[current_key][k] = _parse_scalar(v)
                continue

    return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset generator orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DatasetGenerator:
    """Orchestrates parallel dataset generation with sharding and checkpointing.

    Work is divided into fixed-size shards, generated in a process pool,
    and saved individually so interrupted runs can resume from existing
    shard files. After generation, an index and aggregate statistics are
    written alongside the shards.
    """

    def __init__(self, config: DatasetConfig) -> None:
        self.config = config
        self.output_path = Path(config.output_dir)
        # Shard files live under <output>/shards/; bookkeeping files at top level.
        self.shards_dir = self.output_path / "shards"
        self.checkpoint_file = self.output_path / ".checkpoint.json"
        self.index_file = self.output_path / "index.json"
        self.stats_file = self.output_path / "stats.json"

    # -- public API --

    def run(self) -> None:
        """Generate the full dataset."""
        self.output_path.mkdir(parents=True, exist_ok=True)
        self.shards_dir.mkdir(parents=True, exist_ok=True)

        shards = self._plan_shards()
        total_shards = len(shards)

        # Resume: find already-completed shards
        completed: set[int] = set()
        if self.config.resume:
            completed = self._find_completed_shards()

        pending = [s for s in shards if s.shard_id not in completed]

        if not pending:
            logger.info("All %d shards already complete.", total_shards)
        else:
            logger.info(
                "Generating %d shards (%d already complete).",
                len(pending),
                len(completed),
            )
            progress = _make_progress(len(pending))
            workers = max(1, self.config.num_workers)
            checkpoint_counter = 0

            with ProcessPoolExecutor(max_workers=workers) as pool:
                # Map each future back to its shard id for error reporting.
                futures = {
                    pool.submit(
                        _generate_shard_worker,
                        spec,
                        str(self.output_path),
                    ): spec.shard_id
                    for spec in pending
                }
                for future in as_completed(futures):
                    shard_id = futures[future]
                    try:
                        result = future.result()
                        completed.add(result.shard_id)
                        logger.debug(
                            "Shard %d: %d examples in %.1fs",
                            result.shard_id,
                            result.num_examples,
                            result.generation_time_s,
                        )
                    except Exception:
                        # A failed shard is logged but does not abort the run;
                        # it stays pending for the next resumed invocation.
                        logger.error("Shard %d failed", shard_id, exc_info=True)
                    progress.update(1)
                    checkpoint_counter += 1
                    if checkpoint_counter >= self.config.checkpoint_every:
                        self._update_checkpoint(completed, total_shards)
                        checkpoint_counter = 0

            progress.close()

        # Finalize
        self._build_index()
        stats = self._compute_statistics()
        self._write_statistics(stats)
        self._print_summary(stats)

        # Remove checkpoint (generation complete)
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    # -- internal helpers --

    def _plan_shards(self) -> list[ShardSpec]:
        """Divide num_assemblies into shards."""
        n = self.config.num_assemblies
        size = self.config.shard_size
        num_shards = math.ceil(n / size)
        shards: list[ShardSpec] = []
        for i in range(num_shards):
            start = i * size
            # Last shard may be smaller than shard_size.
            count = min(size, n - start)
            shards.append(
                ShardSpec(
                    shard_id=i,
                    start_example_id=start,
                    count=count,
                    seed=_derive_shard_seed(self.config.seed, i),
                    complexity_distribution=dict(self.config.complexity_distribution),
                    body_count_min=self.config.body_count_min,
                    body_count_max=self.config.body_count_max,
                    grounded_ratio=self.config.grounded_ratio,
                )
            )
        return shards

    def _find_completed_shards(self) -> set[int]:
        """Scan shards directory for existing shard files."""
        completed: set[int] = set()
        if not self.shards_dir.exists():
            return completed

        for p in self.shards_dir.iterdir():
            if p.stem.startswith("shard_"):
                try:
                    shard_id = int(p.stem.split("_")[1])
                    # Verify file is non-empty
                    if p.stat().st_size > 0:
                        completed.add(shard_id)
                except (ValueError, IndexError):
                    # Not a shard file we produced; ignore it.
                    pass
        return completed

    def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
        """Write checkpoint file."""
        data = {
            "completed_shards": sorted(completed),
            "total_shards": total_shards,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f)

    def _build_index(self) -> None:
        """Build index.json mapping shard files to assembly ID ranges."""
        shards_info: dict[str, dict[str, int]] = {}
        total_assemblies = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            try:
                shard_id = int(p.stem.split("_")[1])
            except (ValueError, IndexError):
                continue
            examples = _load_shard(p)
            count = len(examples)
            # Nominal start id; actual count can be lower when examples failed.
            start_id = shard_id * self.config.shard_size
            shards_info[p.name] = {"start_id": start_id, "count": count}
            total_assemblies += count

        fmt = _shard_format().lstrip(".")
        index = {
            "format_version": 1,
            "total_assemblies": total_assemblies,
            "total_shards": len(shards_info),
            "shard_format": fmt,
            "shards": shards_info,
        }
        with open(self.index_file, "w") as f:
            json.dump(index, f, indent=2)

    def _compute_statistics(self) -> dict[str, Any]:
        """Aggregate statistics across all shards."""
        classification_counts: dict[str, int] = {}
        body_count_hist: dict[int, int] = {}
        joint_type_counts: dict[str, int] = {}
        dof_values: list[int] = []
        degeneracy_values: list[int] = []
        rigid_count = 0
        minimally_rigid_count = 0
        total = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            examples = _load_shard(p)
            for ex in examples:
                total += 1
                cls = str(ex.get("assembly_classification", "unknown"))
                classification_counts[cls] = classification_counts.get(cls, 0) + 1
                nb = int(ex.get("n_bodies", 0))
                body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
                for j in ex.get("joints", []):
                    jt = str(j.get("type", "unknown"))
                    joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
                dof_values.append(int(ex.get("internal_dof", 0)))
                degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
                if ex.get("is_rigid"):
                    rigid_count += 1
                if ex.get("is_minimally_rigid"):
                    minimally_rigid_count += 1

        # Zero placeholder keeps mean/min/max defined for an empty dataset.
        dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
        deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)

        return {
            "total_examples": total,
            "classification_distribution": dict(sorted(classification_counts.items())),
            "body_count_histogram": dict(sorted(body_count_hist.items())),
            "joint_type_distribution": dict(sorted(joint_type_counts.items())),
            "dof_statistics": {
                "mean": float(dof_arr.mean()),
                "std": float(dof_arr.std()),
                "min": int(dof_arr.min()),
                "max": int(dof_arr.max()),
                "median": float(np.median(dof_arr)),
            },
            "geometric_degeneracy": {
                "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
                "fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
                "mean_degeneracies": float(deg_arr.mean()),
            },
            "rigidity": {
                "rigid_count": rigid_count,
                "rigid_fraction": (rigid_count / total if total > 0 else 0.0),
                "minimally_rigid_count": minimally_rigid_count,
                "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
            },
        }

    def _write_statistics(self, stats: dict[str, Any]) -> None:
        """Write stats.json."""
        with open(self.stats_file, "w") as f:
            json.dump(stats, f, indent=2)

    def _print_summary(self, stats: dict[str, Any]) -> None:
        """Print a human-readable summary to stdout."""
        print("\n=== Dataset Generation Summary ===")
        print(f"Total examples: {stats['total_examples']}")
        print(f"Output directory: {self.output_path}")
        print()
        print("Classification distribution:")
        for cls, count in stats["classification_distribution"].items():
            frac = count / max(stats["total_examples"], 1) * 100
            print(f"  {cls}: {count} ({frac:.1f}%)")
        print()
        print("Joint type distribution:")
        for jt, count in stats["joint_type_distribution"].items():
            print(f"  {jt}: {count}")
        print()
        dof = stats["dof_statistics"]
        print(
            f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
        )
        rig = stats["rigidity"]
        print(
            f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
            f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
            f"{rig['minimally_rigid_count']} minimally rigid"
        )
        deg = stats["geometric_degeneracy"]
        print(
            f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
            f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
        )
|
||||
893
solver/datagen/generator.py
Normal file
893
solver/datagen/generator.py
Normal file
@@ -0,0 +1,893 @@
|
||||
"""Synthetic assembly graph generator for training data production.
|
||||
|
||||
Generates assembly graphs with known constraint classifications using
|
||||
the pebble game and Jacobian verification. Each assembly is fully labeled
|
||||
with per-constraint independence flags and assembly-level classification.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
|
||||
import numpy as np
|
||||
from scipy.spatial.transform import Rotation
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.labeling import label_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"COMPLEXITY_RANGES",
|
||||
"AxisStrategy",
|
||||
"ComplexityTier",
|
||||
"SyntheticAssemblyGenerator",
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Named complexity tiers used to bucket generated assemblies.
ComplexityTier = Literal["simple", "medium", "complex"]

# Body-count range per tier; the upper bound is exclusive so the pairs can
# be passed directly to rng.integers().
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
    "simple": (2, 6),
    "medium": (6, 16),
    "complex": (16, 51),
}

# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------

# How joint axes are sampled: fixed cardinal directions, uniform random,
# or near-parallel perturbations of a shared base axis.
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
|
||||
|
||||
|
||||
class SyntheticAssemblyGenerator:
|
||||
"""Generates assembly graphs with known minimal constraint sets.
|
||||
|
||||
Uses the pebble game to incrementally build assemblies, tracking
|
||||
exactly which constraints are independent at each step. This produces
|
||||
labeled training data: (assembly_graph, constraint_set, labels).
|
||||
|
||||
Labels per constraint:
|
||||
- independent: bool (does this constraint remove a DOF?)
|
||||
- redundant: bool (is this constraint overconstrained?)
|
||||
- minimal_set: bool (part of a minimal rigidity basis?)
|
||||
"""
|
||||
|
||||
def __init__(self, seed: int = 42) -> None:
|
||||
self.rng = np.random.default_rng(seed)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Private helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _random_position(self, scale: float = 5.0) -> np.ndarray:
|
||||
"""Generate random 3D position within [-scale, scale] cube."""
|
||||
return self.rng.uniform(-scale, scale, size=3)
|
||||
|
||||
def _random_axis(self) -> np.ndarray:
|
||||
"""Generate random normalized 3D axis."""
|
||||
axis = self.rng.standard_normal(3)
|
||||
axis /= np.linalg.norm(axis)
|
||||
return axis
|
||||
|
||||
def _random_orientation(self) -> np.ndarray:
|
||||
"""Generate a random 3x3 rotation matrix."""
|
||||
mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
|
||||
return mat
|
||||
|
||||
def _cardinal_axis(self) -> np.ndarray:
|
||||
"""Pick uniformly from the six signed cardinal directions."""
|
||||
axes = np.array(
|
||||
[
|
||||
[1, 0, 0],
|
||||
[-1, 0, 0],
|
||||
[0, 1, 0],
|
||||
[0, -1, 0],
|
||||
[0, 0, 1],
|
||||
[0, 0, -1],
|
||||
],
|
||||
dtype=float,
|
||||
)
|
||||
result: np.ndarray = axes[int(self.rng.integers(6))]
|
||||
return result
|
||||
|
||||
def _near_parallel_axis(
|
||||
self,
|
||||
base_axis: np.ndarray,
|
||||
perturbation_scale: float = 0.05,
|
||||
) -> np.ndarray:
|
||||
"""Return *base_axis* with a small random perturbation, re-normalized."""
|
||||
perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
|
||||
return perturbed / np.linalg.norm(perturbed)
|
||||
|
||||
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
|
||||
"""Sample a joint axis using the specified strategy."""
|
||||
if strategy == "cardinal":
|
||||
return self._cardinal_axis()
|
||||
if strategy == "near_parallel":
|
||||
return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
|
||||
return self._random_axis()
|
||||
|
||||
def _resolve_axis(
    self,
    strategy: AxisStrategy,
    parallel_axis_prob: float,
    shared_axis: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray | None]:
    """Return ``(axis_for_this_joint, shared_axis_to_propagate)``.

    Once a shared axis exists, every later joint receives a slightly
    perturbed copy of it. When none exists yet, a coin flip with
    probability *parallel_axis_prob* may promote a freshly sampled
    axis to be the shared base for subsequent calls.
    """
    # An axis family is already active: stay near-parallel to it.
    if shared_axis is not None:
        return self._near_parallel_axis(shared_axis), shared_axis

    # The RNG is consulted only when injection is enabled, keeping the
    # draw sequence identical to a run with parallel_axis_prob == 0.
    if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
        base = self._sample_axis(strategy)
        return base.copy(), base

    return self._sample_axis(strategy), None
|
||||
|
||||
def _select_joint_type(
    self,
    joint_types: JointType | list[JointType],
) -> JointType:
    """Return *joint_types* itself, or a uniform pick when it is a list."""
    if not isinstance(joint_types, list):
        return joint_types
    pick = int(self.rng.integers(0, len(joint_types)))
    return joint_types[pick]
|
||||
|
||||
def _create_joint(
    self,
    joint_id: int,
    body_a_id: int,
    body_b_id: int,
    pos_a: np.ndarray,
    pos_b: np.ndarray,
    joint_type: JointType,
    *,
    axis: np.ndarray | None = None,
    orient_a: np.ndarray | None = None,
    orient_b: np.ndarray | None = None,
) -> Joint:
    """Build a :class:`Joint` connecting two bodies.

    With both orientations supplied, each anchor is offset from its
    body's center by a random body-frame vector rotated into the world
    frame. Otherwise both anchors collapse to the midpoint of the two
    body centers. When *axis* is omitted, a random unit axis is drawn.
    """
    if orient_a is not None and orient_b is not None:
        # Offset magnitude scales with separation, floored at 0.1.
        separation = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
        offset_scale = separation * 0.2
        body_frame_a = self.rng.standard_normal(3) * offset_scale
        body_frame_b = self.rng.standard_normal(3) * offset_scale
        anchor_a = pos_a + orient_a @ body_frame_a
        anchor_b = pos_b + orient_b @ body_frame_b
    else:
        midpoint = (pos_a + pos_b) / 2.0
        anchor_a = midpoint
        anchor_b = midpoint

    joint_axis = self._random_axis() if axis is None else axis
    return Joint(
        joint_id=joint_id,
        body_a=body_a_id,
        body_b=body_b_id,
        joint_type=joint_type,
        anchor_a=anchor_a,
        anchor_b=anchor_b,
        axis=joint_axis,
    )
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_chain_assembly(
    self,
    n_bodies: int,
    joint_type: JointType = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a serial kinematic chain.

    Simple but useful: each body connects to the next with the
    specified joint type. Results in an underconstrained assembly
    (serial chain is never rigid without closing loops).

    Args:
        n_bodies: Number of links in the chain.
        joint_type: Joint type used for every link-to-link connection.
        grounded: When True, body 0 is passed as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Returns:
        ``(bodies, joints, analysis)`` for the generated chain.
    """
    bodies = []
    joints = []

    # Bodies laid out along +x, 2 units apart, with random orientations.
    for i in range(n_bodies):
        pos = np.array([i * 2.0, 0.0, 0.0])
        bodies.append(
            RigidBody(
                body_id=i,
                position=pos,
                orientation=self._random_orientation(),
            )
        )

    # One joint per consecutive pair; shared_axis threads the optional
    # near-parallel axis family through the whole chain.
    shared_axis: np.ndarray | None = None
    for i in range(n_bodies - 1):
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                i,
                i,
                i + 1,
                bodies[i].position,
                bodies[i + 1].position,
                joint_type,
                axis=axis,
                orient_a=bodies[i].orientation,
                orient_b=bodies[i + 1].orientation,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
||||
|
||||
def generate_rigid_assembly(
    self,
    n_bodies: int,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a minimally rigid assembly by adding joints until rigid.

    Strategy: start with fixed joints on a spanning tree (guarantees
    rigidity), then randomly relax some to weaker joint types while
    maintaining rigidity via the pebble game check.

    Args:
        n_bodies: Number of bodies, placed at random positions.
        grounded: When True, body 0 is passed as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Returns:
        ``(bodies, joints, analysis)`` with the post-relaxation analysis.
    """
    bodies = []
    for i in range(n_bodies):
        bodies.append(
            RigidBody(
                body_id=i,
                position=self._random_position(),
                orientation=self._random_orientation(),
            )
        )

    # Build spanning tree with fixed joints (overconstrained)
    joints: list[Joint] = []
    shared_axis: np.ndarray | None = None
    for i in range(1, n_bodies):
        # Each new body attaches to a uniformly chosen earlier body,
        # which guarantees connectivity without cycles.
        parent = int(self.rng.integers(0, i))
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                i - 1,
                parent,
                i,
                bodies[parent].position,
                bodies[i].position,
                JointType.FIXED,
                axis=axis,
                orient_a=bodies[parent].orientation,
                orient_b=bodies[i].orientation,
            )
        )

    # Try relaxing joints to weaker types while maintaining rigidity
    weaker_types = [
        JointType.REVOLUTE,
        JointType.CYLINDRICAL,
        JointType.BALL,
    ]

    ground = 0 if grounded else None
    # Visit joints in random order; for each, tentatively weaken it and
    # keep the first candidate type that still leaves the assembly rigid.
    for idx in self.rng.permutation(len(joints)):
        original_type = joints[idx].joint_type
        for candidate in weaker_types:
            joints[idx].joint_type = candidate
            analysis = analyze_assembly(bodies, joints, ground_body=ground)
            if analysis.is_rigid:
                break  # Keep the weaker type
        else:
            # for-else: no candidate preserved rigidity, so restore
            # the original (fixed) joint type.
            joints[idx].joint_type = original_type

    # Re-run the analysis so the returned result matches the final types.
    analysis = analyze_assembly(bodies, joints, ground_body=ground)
    return bodies, joints, analysis
|
||||
|
||||
def generate_overconstrained_assembly(
    self,
    n_bodies: int,
    extra_joints: int = 2,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate an assembly with known redundant constraints.

    Starts with a rigid assembly, then adds extra joints that
    the pebble game will flag as redundant.

    Args:
        n_bodies: Number of bodies (>= 2 so distinct pairs can be drawn).
        extra_joints: Number of redundant joints to append.
        grounded: When True, body 0 is passed as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Returns:
        ``(bodies, joints, analysis)`` including the redundant joints.
    """
    # The base assembly's analysis is discarded; it is recomputed below
    # after the redundant joints are added.
    bodies, joints, _ = self.generate_rigid_assembly(
        n_bodies,
        grounded=grounded,
        axis_strategy=axis_strategy,
        parallel_axis_prob=parallel_axis_prob,
    )

    joint_id = len(joints)
    shared_axis: np.ndarray | None = None
    for _ in range(extra_joints):
        # Distinct random pair of bodies for each redundant joint.
        a, b = self.rng.choice(n_bodies, size=2, replace=False)
        # NOTE: rebuilt each iteration; cheap, but could be hoisted.
        _overcon_types = [
            JointType.REVOLUTE,
            JointType.FIXED,
            JointType.BALL,
        ]
        jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                joint_id,
                int(a),
                int(b),
                bodies[int(a)].position,
                bodies[int(b)].position,
                jtype,
                axis=axis,
                orient_a=bodies[int(a)].orientation,
                orient_b=bodies[int(b)].orientation,
            )
        )
        joint_id += 1

    ground = 0 if grounded else None
    analysis = analyze_assembly(bodies, joints, ground_body=ground)
    return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_tree_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    branching_factor: int = 3,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a random tree topology with configurable branching.

    Creates a tree where each body can have up to *branching_factor*
    children. Different branches can use different joint types if a
    list is provided. Always underconstrained (no closed loops).

    Args:
        n_bodies: Total bodies (root + children).
        joint_types: Single type or list to sample from per joint.
        branching_factor: Max children per parent (1-5 recommended).
        grounded: When True, body 0 (the root) is the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.
    """
    # Root body sits at the origin.
    bodies: list[RigidBody] = [
        RigidBody(
            body_id=0,
            position=np.zeros(3),
            orientation=self._random_orientation(),
        )
    ]
    joints: list[Joint] = []

    # Frontier of bodies still eligible to receive children.
    available_parents = [0]
    next_id = 1
    joint_id = 0
    shared_axis: np.ndarray | None = None

    while next_id < n_bodies and available_parents:
        # Pick a random parent from the frontier.
        pidx = int(self.rng.integers(0, len(available_parents)))
        parent_id = available_parents[pidx]
        parent_pos = bodies[parent_id].position

        # Never create more children than remaining body budget allows.
        max_children = min(branching_factor, n_bodies - next_id)
        n_children = int(self.rng.integers(1, max_children + 1))

        for _ in range(n_children):
            # Child placed 1.5-3.0 units away in a random direction.
            direction = self._random_axis()
            distance = self.rng.uniform(1.5, 3.0)
            child_pos = parent_pos + direction * distance
            child_orient = self._random_orientation()

            bodies.append(
                RigidBody(
                    body_id=next_id,
                    position=child_pos,
                    orientation=child_orient,
                )
            )
            jtype = self._select_joint_type(joint_types)
            axis, shared_axis = self._resolve_axis(
                axis_strategy,
                parallel_axis_prob,
                shared_axis,
            )
            joints.append(
                self._create_joint(
                    joint_id,
                    parent_id,
                    next_id,
                    parent_pos,
                    child_pos,
                    jtype,
                    axis=axis,
                    orient_a=bodies[parent_id].orientation,
                    orient_b=child_orient,
                )
            )

            # New child becomes a candidate parent for later iterations.
            available_parents.append(next_id)
            next_id += 1
            joint_id += 1
            if next_id >= n_bodies:
                break

        # Retire parent if it reached branching limit or randomly
        if n_children >= branching_factor or self.rng.random() < 0.3:
            available_parents.pop(pidx)

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
||||
|
||||
def generate_loop_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a single closed loop (ring) of bodies.

    The closing constraint introduces redundancy, making this
    useful for generating overconstrained training data.

    Args:
        n_bodies: Bodies in the ring (>= 3).
        joint_types: Single type or list to sample from per joint.
        grounded: When True, body 0 is passed as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Raises:
        ValueError: If *n_bodies* < 3.
    """
    if n_bodies < 3:
        msg = "Loop assembly requires at least 3 bodies"
        raise ValueError(msg)

    bodies: list[RigidBody] = []
    joints: list[Joint] = []

    # Ring radius grows with body count; each body gets radial and
    # out-of-plane jitter so the ring is not perfectly planar.
    base_radius = max(2.0, n_bodies * 0.4)
    for i in range(n_bodies):
        angle = 2 * np.pi * i / n_bodies
        radius = base_radius + self.rng.uniform(-0.5, 0.5)
        x = radius * np.cos(angle)
        y = radius * np.sin(angle)
        z = float(self.rng.uniform(-0.2, 0.2))
        bodies.append(
            RigidBody(
                body_id=i,
                position=np.array([x, y, z]),
                orientation=self._random_orientation(),
            )
        )

    shared_axis: np.ndarray | None = None
    for i in range(n_bodies):
        # Wrap-around index closes the ring (last body joins body 0).
        next_i = (i + 1) % n_bodies
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                i,
                i,
                next_i,
                bodies[i].position,
                bodies[next_i].position,
                jtype,
                axis=axis,
                orient_a=bodies[i].orientation,
                orient_b=bodies[next_i].orientation,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
||||
|
||||
def generate_star_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a star topology with central hub and satellites.

    Body 0 is the hub; all other bodies connect directly to it.
    Underconstrained because there are no inter-satellite connections.

    Args:
        n_bodies: Total bodies including hub (>= 2).
        joint_types: Single type or list to sample from per joint.
        grounded: When True, body 0 (the hub) is the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Raises:
        ValueError: If *n_bodies* < 2.
    """
    if n_bodies < 2:
        msg = "Star assembly requires at least 2 bodies"
        raise ValueError(msg)

    # Hub at the origin; its orientation is reused for every joint.
    hub_orient = self._random_orientation()
    bodies: list[RigidBody] = [
        RigidBody(
            body_id=0,
            position=np.zeros(3),
            orientation=hub_orient,
        )
    ]
    joints: list[Joint] = []

    shared_axis: np.ndarray | None = None
    for i in range(1, n_bodies):
        # Satellite placed 2-5 units from the hub in a random direction.
        direction = self._random_axis()
        distance = self.rng.uniform(2.0, 5.0)
        pos = direction * distance
        sat_orient = self._random_orientation()
        bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))

        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                i - 1,
                0,
                i,
                np.zeros(3),
                pos,
                jtype,
                axis=axis,
                orient_a=hub_orient,
                orient_b=sat_orient,
            )
        )

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
||||
|
||||
def generate_mixed_assembly(
    self,
    n_bodies: int,
    joint_types: JointType | list[JointType] = JointType.REVOLUTE,
    edge_density: float = 0.3,
    *,
    grounded: bool = True,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
    """Generate a mixed topology combining tree and loop elements.

    Builds a spanning tree for connectivity, then adds extra edges
    based on *edge_density* to create loops and redundancy.

    Args:
        n_bodies: Number of bodies.
        joint_types: Single type or list to sample from per joint.
        edge_density: Fraction of non-tree edges to add (0.0-1.0).
        grounded: When True, body 0 is passed as the ground body.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of starting a shared-axis family.

    Raises:
        ValueError: If *edge_density* not in [0.0, 1.0].
    """
    if not 0.0 <= edge_density <= 1.0:
        msg = "edge_density must be in [0.0, 1.0]"
        raise ValueError(msg)

    bodies: list[RigidBody] = []
    joints: list[Joint] = []

    for i in range(n_bodies):
        bodies.append(
            RigidBody(
                body_id=i,
                position=self._random_position(),
                orientation=self._random_orientation(),
            )
        )

    # Phase 1: spanning tree
    joint_id = 0
    # Undirected edges recorded as frozensets so (a, b) == (b, a).
    existing_edges: set[frozenset[int]] = set()
    shared_axis: np.ndarray | None = None
    for i in range(1, n_bodies):
        parent = int(self.rng.integers(0, i))
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                joint_id,
                parent,
                i,
                bodies[parent].position,
                bodies[i].position,
                jtype,
                axis=axis,
                orient_a=bodies[parent].orientation,
                orient_b=bodies[i].orientation,
            )
        )
        existing_edges.add(frozenset([parent, i]))
        joint_id += 1

    # Phase 2: extra edges based on density
    # Enumerate every pair not already connected by the tree.
    candidates: list[tuple[int, int]] = []
    for i in range(n_bodies):
        for j in range(i + 1, n_bodies):
            if frozenset([i, j]) not in existing_edges:
                candidates.append((i, j))

    n_extra = int(edge_density * len(candidates))
    self.rng.shuffle(candidates)

    # Take the first n_extra shuffled pairs as loop-closing edges.
    for a, b in candidates[:n_extra]:
        jtype = self._select_joint_type(joint_types)
        axis, shared_axis = self._resolve_axis(
            axis_strategy,
            parallel_axis_prob,
            shared_axis,
        )
        joints.append(
            self._create_joint(
                joint_id,
                a,
                b,
                bodies[a].position,
                bodies[b].position,
                jtype,
                axis=axis,
                orient_a=bodies[a].orientation,
                orient_b=bodies[b].orientation,
            )
        )
        joint_id += 1

    analysis = analyze_assembly(
        bodies,
        joints,
        ground_body=0 if grounded else None,
    )
    return bodies, joints, analysis
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Batch generation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def generate_training_batch(
    self,
    batch_size: int = 100,
    n_bodies_range: tuple[int, int] | None = None,
    complexity_tier: ComplexityTier | None = None,
    *,
    axis_strategy: AxisStrategy = "random",
    parallel_axis_prob: float = 0.0,
    grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
    """Generate a batch of labeled training examples.

    Each example contains body positions, joint descriptions,
    per-joint independence labels, and assembly-level classification.

    Args:
        batch_size: Number of assemblies to generate.
        n_bodies_range: ``(min, max_exclusive)`` body count.
            Overridden by *complexity_tier* when both are given.
        complexity_tier: Predefined range (``"simple"`` / ``"medium"``
            / ``"complex"``). Overrides *n_bodies_range*.
        axis_strategy: Axis sampling strategy for joint axes.
        parallel_axis_prob: Probability of parallel axis injection.
        grounded_ratio: Fraction of examples that are grounded.
    """
    # Resolve the body-count range: tier wins, then explicit range,
    # then the (3, 8) default.
    if complexity_tier is not None:
        n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
    elif n_bodies_range is None:
        n_bodies_range = (3, 8)

    _joint_pool = [
        JointType.REVOLUTE,
        JointType.BALL,
        JointType.CYLINDRICAL,
        JointType.FIXED,
    ]

    # Geometry kwargs forwarded verbatim to every generator call.
    geo_kw: dict[str, Any] = {
        "axis_strategy": axis_strategy,
        "parallel_axis_prob": parallel_axis_prob,
    }

    examples: list[dict[str, Any]] = []

    for i in range(batch_size):
        n = int(self.rng.integers(*n_bodies_range))
        # Uniform choice among the 7 topology generators below.
        gen_idx = int(self.rng.integers(7))
        grounded = bool(self.rng.random() < grounded_ratio)

        if gen_idx == 0:
            # Chains exclude FIXED (a fixed chain is trivially welded).
            _chain_types = [
                JointType.REVOLUTE,
                JointType.BALL,
                JointType.CYLINDRICAL,
            ]
            jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
            bodies, joints, analysis = self.generate_chain_assembly(
                n,
                jtype,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "chain"
        elif gen_idx == 1:
            bodies, joints, analysis = self.generate_rigid_assembly(
                n,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "rigid"
        elif gen_idx == 2:
            extra = int(self.rng.integers(1, 4))
            bodies, joints, analysis = self.generate_overconstrained_assembly(
                n,
                extra,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "overconstrained"
        elif gen_idx == 3:
            branching = int(self.rng.integers(2, 5))
            bodies, joints, analysis = self.generate_tree_assembly(
                n,
                _joint_pool,
                branching,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "tree"
        elif gen_idx == 4:
            # Loop generator raises below 3 bodies; clamp up.
            n = max(n, 3)
            bodies, joints, analysis = self.generate_loop_assembly(
                n,
                _joint_pool,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "loop"
        elif gen_idx == 5:
            # Star generator raises below 2 bodies; clamp up.
            n = max(n, 2)
            bodies, joints, analysis = self.generate_star_assembly(
                n,
                _joint_pool,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "star"
        else:
            density = float(self.rng.uniform(0.2, 0.5))
            bodies, joints, analysis = self.generate_mixed_assembly(
                n,
                _joint_pool,
                density,
                grounded=grounded,
                **geo_kw,
            )
            gen_name = "mixed"

        # Produce ground truth labels (includes ConstraintAnalysis)
        ground = 0 if grounded else None
        labels = label_assembly(bodies, joints, ground_body=ground)
        # Replace the generator's analysis with the labeling pass's one
        # so the example fields all come from a single consistent run.
        analysis = labels.analysis

        # Build per-joint labels from edge results
        joint_labels: dict[int, dict[str, int]] = {}
        for result in analysis.per_edge_results:
            jid = result["joint_id"]
            if jid not in joint_labels:
                joint_labels[jid] = {
                    "independent_constraints": 0,
                    "redundant_constraints": 0,
                    "total_constraints": 0,
                }
            joint_labels[jid]["total_constraints"] += 1
            if result["independent"]:
                joint_labels[jid]["independent_constraints"] += 1
            else:
                joint_labels[jid]["redundant_constraints"] += 1

        examples.append(
            {
                "example_id": i,
                "generator_type": gen_name,
                "grounded": grounded,
                "n_bodies": len(bodies),
                "n_joints": len(joints),
                "body_positions": [b.position.tolist() for b in bodies],
                "body_orientations": [b.orientation.tolist() for b in bodies],
                "joints": [
                    {
                        "joint_id": j.joint_id,
                        "body_a": j.body_a,
                        "body_b": j.body_b,
                        "type": j.joint_type.name,
                        "axis": j.axis.tolist(),
                    }
                    for j in joints
                ],
                "joint_labels": joint_labels,
                "labels": labels.to_dict(),
                "assembly_classification": (analysis.combinatorial_classification),
                "is_rigid": analysis.is_rigid,
                "is_minimally_rigid": analysis.is_minimally_rigid,
                "internal_dof": analysis.jacobian_internal_dof,
                "geometric_degeneracies": (analysis.geometric_degeneracies),
            }
        )

    return examples
|
||||
517
solver/datagen/jacobian.py
Normal file
517
solver/datagen/jacobian.py
Normal file
@@ -0,0 +1,517 @@
|
||||
"""Numerical Jacobian rank verification for assembly constraint analysis.
|
||||
|
||||
Builds the constraint Jacobian matrix and analyzes its numerical rank
|
||||
to detect geometric degeneracies that the combinatorial pebble game
|
||||
cannot identify (e.g., parallel revolute axes creating hidden dependencies).
|
||||
|
||||
References:
|
||||
- Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D"
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["JacobianVerifier"]
|
||||
|
||||
|
||||
class JacobianVerifier:
|
||||
"""Builds and analyzes the constraint Jacobian for numerical rank check.
|
||||
|
||||
The pebble game gives a combinatorial *necessary* condition for
|
||||
rigidity. However, geometric special cases (e.g., all revolute axes
|
||||
parallel, creating a hidden dependency) require numerical verification.
|
||||
|
||||
For each joint, we construct the constraint Jacobian rows that map
|
||||
the 6n-dimensional generalized velocity vector to the constraint
|
||||
violation rates. The rank of this Jacobian equals the number of
|
||||
truly independent constraints.
|
||||
|
||||
The generalized velocity vector for n bodies is::
|
||||
|
||||
v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z]
|
||||
|
||||
Each scalar constraint C_i contributes one row to J such that::
|
||||
|
||||
dC_i/dt = J_i @ v = 0
|
||||
"""
|
||||
|
||||
def __init__(self, bodies: list[RigidBody]) -> None:
    """Index the bodies and start with an empty Jacobian."""
    # Both lookup tables are keyed by body_id: one maps to the body
    # itself, the other to its dense column-block index.
    self.bodies = {}
    self.body_index = {}
    for dense_idx, body in enumerate(bodies):
        self.bodies[body.body_id] = body
        self.body_index[body.body_id] = dense_idx
    self.n_bodies = len(bodies)
    # Rows and their descriptive labels are kept in lockstep.
    self.jacobian_rows: list[np.ndarray] = []
    self.row_labels: list[dict[str, Any]] = []
|
||||
|
||||
def _body_cols(self, body_id: int) -> tuple[int, int]:
    """Return the half-open column range ``[start, end)`` for a body in J."""
    start = self.body_index[body_id] * 6
    return start, start + 6
|
||||
|
||||
def add_joint_constraints(self, joint: Joint) -> int:
    """Append Jacobian rows for every scalar constraint of *joint*.

    Dispatches to the per-type ``_build_*`` method and returns the
    number of rows that method appended.
    """
    dispatch = {
        JointType.FIXED: self._build_fixed,
        JointType.REVOLUTE: self._build_revolute,
        JointType.CYLINDRICAL: self._build_cylindrical,
        JointType.SLIDER: self._build_slider,
        JointType.BALL: self._build_ball,
        JointType.PLANAR: self._build_planar,
        JointType.DISTANCE: self._build_distance,
        JointType.PARALLEL: self._build_parallel,
        JointType.PERPENDICULAR: self._build_perpendicular,
        JointType.UNIVERSAL: self._build_universal,
        JointType.SCREW: self._build_screw,
    }

    n_before = len(self.jacobian_rows)
    dispatch[joint.joint_type](joint)
    return len(self.jacobian_rows) - n_before
|
||||
|
||||
def _make_row(self) -> np.ndarray:
    """Allocate an all-zero Jacobian row of width ``6 * n_bodies``."""
    return np.zeros(self.n_bodies * 6)
|
||||
|
||||
def _skew(self, v: np.ndarray) -> np.ndarray:
    """Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``."""
    x, y, z = v
    return np.array(
        [
            [0, -z, y],
            [z, 0, -x],
            [-y, x, 0],
        ]
    )
|
||||
|
||||
# --- Ball-and-socket (spherical) joint: 3 translation constraints ---
|
||||
|
||||
def _build_ball(self, joint: Joint) -> None:
    """Ball joint: coincident point constraint.

    ``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0``
    (3 equations)

    Jacobian rows (for each of x, y, z):
        body_a linear: -I
        body_a angular: +skew(R_a @ r_a)
        body_b linear: +I
        body_b angular: -skew(R_b @ r_b)

    NOTE(review): the implementation writes ``np.cross(r_a, e)`` into
    body_a's angular block, which is ``r_a x e = -(e x r_a)`` — the
    NEGATIVE of the ``+skew(r_a)`` row documented above (since
    ``e^T skew(v) = e x v``). The same flip appears on every
    translation row of every builder, i.e. a uniform sign change of the
    angular columns, which cannot change the Jacobian's rank — the only
    quantity this class reports. Confirm the sign convention before
    reusing these rows for anything beyond rank analysis.
    """
    # Use anchor positions directly as world-frame offsets
    r_a = joint.anchor_a - self.bodies[joint.body_a].position
    r_b = joint.anchor_b - self.bodies[joint.body_b].position

    col_a_start, col_a_end = self._body_cols(joint.body_a)
    col_b_start, col_b_end = self._body_cols(joint.body_b)

    # One row per world axis e in {x, y, z}.
    for axis_idx in range(3):
        row = self._make_row()
        e = np.zeros(3)
        e[axis_idx] = 1.0

        row[col_a_start : col_a_start + 3] = -e
        row[col_a_start + 3 : col_a_end] = np.cross(r_a, e)

        row[col_b_start : col_b_start + 3] = e
        row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e)

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "ball_translation",
                "axis": axis_idx,
            }
        )
|
||||
|
||||
# --- Fixed joint: 3 translation + 3 rotation constraints ---
|
||||
|
||||
def _build_fixed(self, joint: Joint) -> None:
    """Fixed joint = ball joint + locked rotation.

    Translation part: same as ball joint (3 rows).
    Rotation part: relative angular velocity must be zero (3 rows).
    """
    self._build_ball(joint)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    for k in range(3):
        # Component k of (w_b - w_a) must vanish.
        row = self._make_row()
        row[a_start + 3 + k] = -1.0
        row[b_start + 3 + k] = 1.0

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "fixed_rotation",
                "axis": k,
            }
        )
|
||||
|
||||
# --- Revolute (hinge) joint: 3 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_revolute(self, joint: Joint) -> None:
    """Revolute joint: rotation only about one axis.

    Translation: same as ball (3 rows).
    Rotation: relative angular velocity must be parallel to hinge axis.
    """
    self._build_ball(joint)

    hinge = joint.axis / np.linalg.norm(joint.axis)
    perp_u, perp_v = self._perpendicular_pair(hinge)

    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    for idx, perp in enumerate((perp_u, perp_v)):
        # (w_b - w_a) must have no component along this perpendicular.
        row = self._make_row()
        row[a_start + 3 : a_start + 6] = -perp
        row[b_start + 3 : b_start + 6] = perp

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "revolute_rotation",
                "perp_axis": idx,
            }
        )
|
||||
|
||||
# --- Cylindrical joint: 2 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_cylindrical(self, joint: Joint) -> None:
    """Cylindrical joint: allows rotation + translation along one axis.

    Translation: constrain motion perpendicular to axis (2 rows).
    Rotation: constrain rotation perpendicular to axis (2 rows).
    """
    slide_axis = joint.axis / np.linalg.norm(joint.axis)
    perp_u, perp_v = self._perpendicular_pair(slide_axis)

    lever_a = joint.anchor_a - self.bodies[joint.body_a].position
    lever_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # Translation rows: relative anchor velocity restricted to the axis.
    for idx, perp in enumerate((perp_u, perp_v)):
        row = self._make_row()
        row[a_start : a_start + 3] = -perp
        row[a_start + 3 : a_end] = np.cross(lever_a, perp)
        row[b_start : b_start + 3] = perp
        row[b_start + 3 : b_end] = -np.cross(lever_b, perp)

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "cylindrical_translation",
                "perp_axis": idx,
            }
        )

    # Rotation rows: relative angular velocity stays parallel to axis.
    for idx, perp in enumerate((perp_u, perp_v)):
        row = self._make_row()
        row[a_start + 3 : a_start + 6] = -perp
        row[b_start + 3 : b_start + 6] = perp

        self.jacobian_rows.append(row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "cylindrical_rotation",
                "perp_axis": idx,
            }
        )
|
||||
|
||||
# --- Slider (prismatic) joint: 2 translation + 3 rotation constraints ---
|
||||
|
||||
def _build_slider(self, joint: Joint) -> None:
    """Slider/prismatic joint: translation along one axis only.

    Emits 2 rows forbidding perpendicular translation, then 3 rows
    forbidding all relative rotation.
    """
    unit_axis = joint.axis / np.linalg.norm(joint.axis)
    perps = self._perpendicular_pair(unit_axis)

    # Lever arms from each body's reference position to its anchor.
    lever_a = joint.anchor_a - self.bodies[joint.body_a].position
    lever_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # Translation rows: no relative motion perpendicular to the axis.
    for idx, perp in enumerate(perps):
        jac_row = self._make_row()
        jac_row[a_start : a_start + 3] = -perp
        jac_row[a_start + 3 : a_end] = np.cross(lever_a, perp)
        jac_row[b_start : b_start + 3] = perp
        jac_row[b_start + 3 : b_end] = -np.cross(lever_b, perp)

        self.jacobian_rows.append(jac_row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "slider_translation",
                "perp_axis": idx,
            }
        )

    # Rotation rows: relative angular velocity is fully locked (3 axes).
    for spin_axis in range(3):
        jac_row = self._make_row()
        jac_row[a_start + 3 + spin_axis] = -1.0
        jac_row[b_start + 3 + spin_axis] = 1.0

        self.jacobian_rows.append(jac_row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "slider_rotation",
                "axis": spin_axis,
            }
        )
|
||||
|
||||
# --- Planar joint: 1 translation + 2 rotation constraints ---
|
||||
|
||||
def _build_planar(self, joint: Joint) -> None:
    """Planar joint: constrains a body to a plane.

    Emits 1 row forbidding translation along the plane normal and
    2 rows forbidding rotation about the in-plane axes.
    """
    plane_normal = joint.axis / np.linalg.norm(joint.axis)
    in_plane = self._perpendicular_pair(plane_normal)

    # Lever arms from each body's reference position to its anchor.
    lever_a = joint.anchor_a - self.bodies[joint.body_a].position
    lever_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    # Single translation row along the plane normal.
    normal_row = self._make_row()
    normal_row[a_start : a_start + 3] = -plane_normal
    normal_row[a_start + 3 : a_end] = np.cross(lever_a, plane_normal)
    normal_row[b_start : b_start + 3] = plane_normal
    normal_row[b_start + 3 : b_end] = -np.cross(lever_b, plane_normal)
    self.jacobian_rows.append(normal_row)
    self.row_labels.append(
        {
            "joint_id": joint.joint_id,
            "type": "planar_translation",
        }
    )

    # Rotation rows about the two in-plane directions.
    for idx, direction in enumerate(in_plane):
        jac_row = self._make_row()
        jac_row[a_start + 3 : a_start + 6] = -direction
        jac_row[b_start + 3 : b_start + 6] = direction
        self.jacobian_rows.append(jac_row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "planar_rotation",
                "perp_axis": idx,
            }
        )
|
||||
|
||||
# --- Distance constraint: 1 scalar ---
|
||||
|
||||
def _build_distance(self, joint: Joint) -> None:
    """Distance constraint: ``||p_b - p_a|| = d``.

    Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0``
    where ``direction = normalized(p_b - p_a)``.
    """
    separation = joint.anchor_b - joint.anchor_a
    length = np.linalg.norm(separation)
    # Coincident anchors leave the direction undefined; fall back to +x.
    if length < 1e-12:
        direction = np.array([1.0, 0.0, 0.0])
    else:
        direction = separation / length

    # Lever arms from each body's reference position to its anchor.
    lever_a = joint.anchor_a - self.bodies[joint.body_a].position
    lever_b = joint.anchor_b - self.bodies[joint.body_b].position

    a_start, a_end = self._body_cols(joint.body_a)
    b_start, b_end = self._body_cols(joint.body_b)

    jac_row = self._make_row()
    jac_row[a_start : a_start + 3] = -direction
    jac_row[a_start + 3 : a_end] = np.cross(lever_a, direction)
    jac_row[b_start : b_start + 3] = direction
    jac_row[b_start + 3 : b_end] = -np.cross(lever_b, direction)

    self.jacobian_rows.append(jac_row)
    self.row_labels.append(
        {
            "joint_id": joint.joint_id,
            "type": "distance",
        }
    )
|
||||
|
||||
# --- Parallel constraint: 3 rotation constraints ---
|
||||
|
||||
def _build_parallel(self, joint: Joint) -> None:
    """Parallel: all relative rotation constrained (same as fixed rotation).

    In practice only 2 of 3 are independent for a single axis, but
    we emit 3 and let the rank check sort it out.
    """
    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # One row per angular axis: equal angular velocities for both bodies.
    for spin_axis in range(3):
        jac_row = self._make_row()
        jac_row[a_start + 3 + spin_axis] = -1.0
        jac_row[b_start + 3 + spin_axis] = 1.0
        self.jacobian_rows.append(jac_row)
        self.row_labels.append(
            {
                "joint_id": joint.joint_id,
                "type": "parallel_rotation",
                "axis": spin_axis,
            }
        )
|
||||
|
||||
# --- Perpendicular constraint: 1 angular ---
|
||||
|
||||
def _build_perpendicular(self, joint: Joint) -> None:
    """Perpendicular: single dot-product angular constraint."""
    unit_axis = joint.axis / np.linalg.norm(joint.axis)
    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # One angular row along the joint axis.
    jac_row = self._make_row()
    jac_row[a_start + 3 : a_start + 6] = -unit_axis
    jac_row[b_start + 3 : b_start + 6] = unit_axis
    self.jacobian_rows.append(jac_row)
    self.row_labels.append(
        {
            "joint_id": joint.joint_id,
            "type": "perpendicular",
        }
    )
|
||||
|
||||
# --- Universal (Cardan) joint: 3 translation + 1 rotation ---
|
||||
|
||||
def _build_universal(self, joint: Joint) -> None:
    """Universal joint: ball + one rotation constraint.

    Allows rotation about two axes, constrains rotation about the third.
    """
    # Point coincidence first (3 translation rows via the ball joint).
    self._build_ball(joint)

    unit_axis = joint.axis / np.linalg.norm(joint.axis)
    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # One angular row locking rotation about the joint axis.
    jac_row = self._make_row()
    jac_row[a_start + 3 : a_start + 6] = -unit_axis
    jac_row[b_start + 3 : b_start + 6] = unit_axis
    self.jacobian_rows.append(jac_row)
    self.row_labels.append(
        {
            "joint_id": joint.joint_id,
            "type": "universal_rotation",
        }
    )
|
||||
|
||||
# --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled ---
|
||||
|
||||
def _build_screw(self, joint: Joint) -> None:
    """Screw joint: coupled rotation-translation along axis.

    Like cylindrical but with a coupling constraint:
    ``v_axial - pitch * w_axial = 0``
    """
    # The four perpendicular rows are identical to a cylindrical joint.
    self._build_cylindrical(joint)

    unit_axis = joint.axis / np.linalg.norm(joint.axis)
    a_start, _ = self._body_cols(joint.body_a)
    b_start, _ = self._body_cols(joint.body_b)

    # Coupling row: axial translation tied to axial rotation by the pitch.
    jac_row = self._make_row()
    jac_row[a_start : a_start + 3] = -unit_axis
    jac_row[b_start : b_start + 3] = unit_axis
    jac_row[a_start + 3 : a_start + 6] = joint.pitch * unit_axis
    jac_row[b_start + 3 : b_start + 6] = -joint.pitch * unit_axis
    self.jacobian_rows.append(jac_row)
    self.row_labels.append(
        {
            "joint_id": joint.joint_id,
            "type": "screw_coupling",
        }
    )
|
||||
|
||||
# --- Utilities ---
|
||||
|
||||
def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Generate two unit vectors perpendicular to *axis* and each other."""
|
||||
if abs(axis[0]) < 0.9:
|
||||
t1 = np.cross(axis, np.array([1.0, 0, 0]))
|
||||
else:
|
||||
t1 = np.cross(axis, np.array([0, 1.0, 0]))
|
||||
t1 /= np.linalg.norm(t1)
|
||||
t2 = np.cross(axis, t1)
|
||||
t2 /= np.linalg.norm(t2)
|
||||
return t1, t2
|
||||
|
||||
def get_jacobian(self) -> np.ndarray:
    """Return the full constraint Jacobian as a 2-D array.

    An assembly with no constraint rows yields an empty
    ``(0, 6 * n_bodies)`` matrix rather than raising.
    """
    rows = self.jacobian_rows
    if rows:
        return np.array(rows)
    return np.zeros((0, 6 * self.n_bodies))
|
||||
|
||||
def numerical_rank(self, tol: float = 1e-8) -> int:
    """Compute the numerical rank of the constraint Jacobian via SVD.

    This is the number of truly independent scalar constraints,
    accounting for geometric degeneracies that the combinatorial
    pebble game cannot detect.
    """
    jac = self.get_jacobian()
    if jac.size == 0:
        return 0
    singular_values = np.linalg.svd(jac, compute_uv=False)
    # Rank = number of singular values above the tolerance.
    return int((singular_values > tol).sum())
|
||||
|
||||
def find_dependencies(self, tol: float = 1e-8) -> list[int]:
    """Identify which constraint rows are numerically dependent.

    Returns indices of rows that can be removed without changing
    the Jacobian's rank. Greedy scan: a row is kept iff appending it
    to the rows kept so far raises the numerical rank.
    """
    jac = self.get_jacobian()
    if jac.size == 0:
        return []

    dependent_rows: list[int] = []
    kept = np.zeros((0, jac.shape[1]))
    rank_so_far = 0

    for row_idx in range(jac.shape[0]):
        row = jac[row_idx : row_idx + 1, :]
        trial = np.vstack([kept, row]) if kept.size else row
        trial_rank = int(np.sum(np.linalg.svd(trial, compute_uv=False) > tol))

        if trial_rank > rank_so_far:
            kept = trial
            rank_so_far = trial_rank
        else:
            dependent_rows.append(row_idx)

    return dependent_rows
|
||||
394
solver/datagen/labeling.py
Normal file
394
solver/datagen/labeling.py
Normal file
@@ -0,0 +1,394 @@
|
||||
"""Ground truth labeling pipeline for synthetic assemblies.
|
||||
|
||||
Produces rich per-constraint, per-joint, per-body, and assembly-level
|
||||
labels by running both the pebble game and Jacobian verification and
|
||||
correlating their results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["AssemblyLabels", "label_assembly"]
|
||||
|
||||
_GROUND_ID = -1
|
||||
_SVD_TOL = 1e-8
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Label dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class ConstraintLabel:
    """Per scalar-constraint label combining both analysis methods."""

    # Joint this scalar constraint belongs to.
    joint_id: int
    # 0-based index of the constraint within its joint.
    constraint_idx: int
    # True if the (6,6)-pebble game accepted the corresponding edge.
    pebble_independent: bool
    # True if the matching Jacobian row is not numerically dependent.
    jacobian_independent: bool
|
||||
|
||||
|
||||
@dataclass
class JointLabel:
    """Aggregated constraint counts for a single joint."""

    # Joint being summarized.
    joint_id: int
    # Scalar constraints the pebble game found independent.
    independent_count: int
    # Scalar constraints the pebble game found redundant.
    redundant_count: int
    # Total scalar constraints emitted for this joint.
    total: int
|
||||
|
||||
|
||||
@dataclass
class BodyDofLabel:
    """Per-body DOF signature from nullspace projection."""

    # Body being described.
    body_id: int
    # Remaining translational freedoms (0-3), rank of the nullspace block.
    translational_dof: int
    # Remaining rotational freedoms (0-3), rank of the nullspace block.
    rotational_dof: int
|
||||
|
||||
|
||||
@dataclass
class AssemblyLabel:
    """Assembly-wide summary label."""

    # One of "well-constrained", "underconstrained", "overconstrained", "mixed".
    classification: str
    # Internal DOF derived from the Jacobian nullity (clamped at zero).
    total_dof: int
    # Redundant scalar constraints counted by the pebble game.
    redundant_count: int
    # True when the Jacobian nullity is at most the trivial-motion count.
    is_rigid: bool
    # Rigid with no numerically dependent constraint rows.
    is_minimally_rigid: bool
    # True when the pebble game accepted more constraints than the Jacobian rank.
    has_degeneracy: bool
|
||||
|
||||
|
||||
@dataclass
class AssemblyLabels:
    """Complete ground truth labels for an assembly."""

    per_constraint: list[ConstraintLabel]
    per_joint: list[JointLabel]
    per_body: list[BodyDofLabel]
    assembly: AssemblyLabel
    analysis: ConstraintAnalysis

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict.

        The embedded ``analysis`` object is intentionally excluded;
        only the label sections are serialized.
        """
        constraint_dicts = [
            {
                "joint_id": item.joint_id,
                "constraint_idx": item.constraint_idx,
                "pebble_independent": item.pebble_independent,
                "jacobian_independent": item.jacobian_independent,
            }
            for item in self.per_constraint
        ]
        joint_dicts = [
            {
                "joint_id": item.joint_id,
                "independent_count": item.independent_count,
                "redundant_count": item.redundant_count,
                "total": item.total,
            }
            for item in self.per_joint
        ]
        body_dicts = [
            {
                "body_id": item.body_id,
                "translational_dof": item.translational_dof,
                "rotational_dof": item.rotational_dof,
            }
            for item in self.per_body
        ]
        summary = {
            "classification": self.assembly.classification,
            "total_dof": self.assembly.total_dof,
            "redundant_count": self.assembly.redundant_count,
            "is_rigid": self.assembly.is_rigid,
            "is_minimally_rigid": self.assembly.is_minimally_rigid,
            "has_degeneracy": self.assembly.has_degeneracy,
        }
        return {
            "per_constraint": constraint_dicts,
            "per_joint": joint_dicts,
            "per_body": body_dicts,
            "assembly": summary,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-body DOF from nullspace projection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_per_body_dof(
    j_reduced: np.ndarray,
    body_ids: list[int],
    ground_body: int | None,
    body_index: dict[int, int],
) -> list[BodyDofLabel]:
    """Compute translational and rotational DOF per body.

    Uses SVD nullspace projection: for each body, extract its
    translational (3 cols) and rotational (3 cols) components
    from the nullspace basis and compute ranks.

    NOTE(review): *body_index* is accepted for interface compatibility
    but the column mapping is rebuilt locally, since ground-body
    columns were removed from *j_reduced*.
    """

    def block_rank(cols: np.ndarray) -> int:
        # Rank of a 3-column nullspace slice = remaining DOF of that kind.
        if cols.size == 0:
            return 0
        singular = np.linalg.svd(cols, compute_uv=False)
        return int(np.sum(singular > _SVD_TOL))

    # Column position of each non-ground body within the reduced Jacobian.
    reduced_pos = {}
    next_pos = 0
    for bid in body_ids:
        if bid != ground_body:
            reduced_pos[bid] = next_pos
            next_pos += 1

    if j_reduced.size == 0:
        # No constraints — every non-ground body is fully free.
        return [
            BodyDofLabel(
                body_id=bid,
                translational_dof=0 if bid == ground_body else 3,
                rotational_dof=0 if bid == ground_body else 3,
            )
            for bid in body_ids
        ]

    # Full SVD so Vh contains a complete basis including the nullspace.
    _u, singular, vh = np.linalg.svd(j_reduced, full_matrices=True)
    rank = int(np.sum(singular > _SVD_TOL))
    n_cols = j_reduced.shape[1]

    if rank >= n_cols:
        # Fully constrained — no nullspace, nothing moves.
        return [
            BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0)
            for bid in body_ids
        ]

    # Nullspace basis: rows of Vh beyond the rank.
    nullspace = vh[rank:]

    labels: list[BodyDofLabel] = []
    for bid in body_ids:
        if bid == ground_body:
            labels.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
            continue

        base = reduced_pos[bid] * 6
        labels.append(
            BodyDofLabel(
                body_id=bid,
                translational_dof=block_rank(nullspace[:, base : base + 3]),
                rotational_dof=block_rank(nullspace[:, base + 3 : base + 6]),
            )
        )

    return labels
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Main labeling function
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def label_assembly(
    bodies: list[RigidBody],
    joints: list[Joint],
    ground_body: int | None = None,
) -> AssemblyLabels:
    """Produce complete ground truth labels for an assembly.

    Runs both the pebble game and Jacobian verification internally,
    then correlates their results into per-constraint, per-joint,
    per-body, and assembly-level labels.

    Args:
        bodies: Rigid bodies in the assembly.
        joints: Joints connecting the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        AssemblyLabels with full label set and embedded ConstraintAnalysis.
    """
    # ---- Pebble Game ----
    pg = PebbleGame3D()
    all_edge_results: list[dict[str, Any]] = []

    if ground_body is not None:
        # Virtual "world" vertex the ground body will be pinned to.
        pg.add_body(_GROUND_ID)

    for body in bodies:
        pg.add_body(body.body_id)

    if ground_body is not None:
        # Pin ground_body to the world with a synthetic FIXED joint.
        # Its edge results are deliberately NOT recorded in
        # all_edge_results so they never appear in per-constraint labels.
        ground_joint = Joint(
            joint_id=-1,
            body_a=ground_body,
            body_b=_GROUND_ID,
            joint_type=JointType.FIXED,
            anchor_a=bodies[0].position if bodies else np.zeros(3),
            anchor_b=bodies[0].position if bodies else np.zeros(3),
        )
        pg.add_joint(ground_joint)

    for joint in joints:
        results = pg.add_joint(joint)
        all_edge_results.extend(results)

    grounded = ground_body is not None
    combinatorial_independent = len(pg.state.independent_edges)
    raw_dof = pg.get_dof()
    # Discount the 6 pebbles belonging to the virtual world vertex.
    ground_offset = 6 if grounded else 0
    effective_dof = raw_dof - ground_offset
    # Ungrounded assemblies keep 6 trivial rigid-body motions; grounded
    # ones have none left, so all remaining DOF are internal.
    effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))

    redundant_count = pg.get_redundant_count()
    if redundant_count > 0 and effective_internal_dof > 0:
        classification = "mixed"
    elif redundant_count > 0:
        classification = "overconstrained"
    elif effective_internal_dof > 0:
        classification = "underconstrained"
    else:
        classification = "well-constrained"

    # ---- Jacobian Verification ----
    verifier = JacobianVerifier(bodies)

    for joint in joints:
        verifier.add_joint_constraints(joint)

    j_full = verifier.get_jacobian()
    j_reduced = j_full.copy()
    if ground_body is not None and j_reduced.size > 0:
        # Grounding is modeled by deleting the ground body's 6 columns.
        idx = verifier.body_index[ground_body]
        cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
        j_reduced = np.delete(j_reduced, cols_to_remove, axis=1)

    if j_reduced.size > 0:
        sv = np.linalg.svd(j_reduced, compute_uv=False)
        jacobian_rank = int(np.sum(sv > _SVD_TOL))
    else:
        jacobian_rank = 0

    n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies)
    jacobian_nullity = n_cols - jacobian_rank
    # NOTE(review): find_dependencies() runs on the FULL Jacobian
    # (ground columns included) while the rank above uses the reduced
    # one — confirm the row-dependence sets are intended to match.
    dependent_rows = verifier.find_dependencies()
    dependent_set = set(dependent_rows)

    trivial_dof = 0 if grounded else 6
    jacobian_internal_dof = jacobian_nullity - trivial_dof
    geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
    is_rigid = jacobian_nullity <= trivial_dof
    is_minimally_rigid = is_rigid and len(dependent_rows) == 0

    # ---- Per-constraint labels ----
    # Map Jacobian rows to (joint_id, constraint_index).
    # Rows are added contiguously per joint in the same order as joints.
    # Assumes the verifier emits exactly joint_type.dof rows per joint
    # — TODO confirm this holds for every joint type.
    row_to_joint: list[tuple[int, int]] = []
    for joint in joints:
        dof = joint.joint_type.dof
        for ci in range(dof):
            row_to_joint.append((joint.joint_id, ci))

    per_constraint: list[ConstraintLabel] = []
    for edge_idx, edge_result in enumerate(all_edge_results):
        jid = edge_result["joint_id"]
        ci = edge_result["constraint_index"]
        pebble_indep = edge_result["independent"]

        # Find matching Jacobian row; edges past the known rows default
        # to independent.
        jacobian_indep = True
        if edge_idx < len(row_to_joint):
            row_idx = edge_idx
            jacobian_indep = row_idx not in dependent_set

        per_constraint.append(
            ConstraintLabel(
                joint_id=jid,
                constraint_idx=ci,
                pebble_independent=pebble_indep,
                jacobian_independent=jacobian_indep,
            )
        )

    # ---- Per-joint labels ----
    # Aggregate pebble-game verdicts per joint.
    joint_agg: dict[int, JointLabel] = {}
    for cl in per_constraint:
        if cl.joint_id not in joint_agg:
            joint_agg[cl.joint_id] = JointLabel(
                joint_id=cl.joint_id,
                independent_count=0,
                redundant_count=0,
                total=0,
            )
        jl = joint_agg[cl.joint_id]
        jl.total += 1
        if cl.pebble_independent:
            jl.independent_count += 1
        else:
            jl.redundant_count += 1

    # Preserve the caller's joint ordering in the output.
    per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg]

    # ---- Per-body DOF labels ----
    body_ids = [b.body_id for b in bodies]
    per_body = _compute_per_body_dof(
        j_reduced,
        body_ids,
        ground_body,
        verifier.body_index,
    )

    # ---- Assembly label ----
    assembly_label = AssemblyLabel(
        classification=classification,
        total_dof=max(0, jacobian_internal_dof),
        redundant_count=redundant_count,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
        has_degeneracy=geometric_degeneracies > 0,
    )

    # ---- ConstraintAnalysis (for backward compat) ----
    analysis = ConstraintAnalysis(
        combinatorial_dof=effective_dof,
        combinatorial_internal_dof=effective_internal_dof,
        combinatorial_redundant=redundant_count,
        combinatorial_classification=classification,
        per_edge_results=all_edge_results,
        jacobian_rank=jacobian_rank,
        jacobian_nullity=jacobian_nullity,
        jacobian_internal_dof=max(0, jacobian_internal_dof),
        numerically_dependent=dependent_rows,
        geometric_degeneracies=geometric_degeneracies,
        is_rigid=is_rigid,
        is_minimally_rigid=is_minimally_rigid,
    )

    return AssemblyLabels(
        per_constraint=per_constraint,
        per_joint=per_joint,
        per_body=per_body,
        assembly=assembly_label,
        analysis=analysis,
    )
|
||||
258
solver/datagen/pebble_game.py
Normal file
258
solver/datagen/pebble_game.py
Normal file
@@ -0,0 +1,258 @@
|
||||
"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis.
|
||||
|
||||
Implements the pebble game algorithm adapted for CAD assembly constraint
|
||||
graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints
|
||||
between bodies remove DOF according to their type.
|
||||
|
||||
The pebble game provides a fast combinatorial *necessary* condition for
|
||||
rigidity via Tay's theorem. It does not detect geometric degeneracies —
|
||||
use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient*
|
||||
condition.
|
||||
|
||||
References:
|
||||
- Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008
|
||||
- Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity
|
||||
Percolation: The Pebble Game", J. Comput. Phys., 1997
|
||||
- Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.types import Joint, PebbleState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = ["PebbleGame3D"]
|
||||
|
||||
|
||||
class PebbleGame3D:
|
||||
"""Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks.
|
||||
|
||||
For body-bar-hinge structures in 3D, Tay's theorem states that a
|
||||
multigraph G on n vertices is generically minimally rigid iff:
|
||||
|E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2)
|
||||
|
||||
The (6,6)-pebble game tests this sparsity condition incrementally.
|
||||
Each vertex starts with 6 pebbles (representing 6 DOF). To insert
|
||||
an edge, we need to collect 6+1=7 pebbles on its two endpoints.
|
||||
If we can, the edge is independent (removes a DOF). If not, it's
|
||||
redundant (overconstrained).
|
||||
|
||||
In the CAD assembly context:
|
||||
- Vertices = rigid bodies
|
||||
- Edges = scalar constraints from joints
|
||||
- A revolute joint (5 DOF removed) maps to 5 multigraph edges
|
||||
- A fixed joint (6 DOF removed) maps to 6 multigraph edges
|
||||
"""
|
||||
|
||||
K = 6 # Pebbles per vertex (DOF per rigid body in 3D)
|
||||
L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge
|
||||
|
||||
def __init__(self) -> None:
|
||||
self.state = PebbleState()
|
||||
self._edge_counter = 0
|
||||
self._bodies: set[int] = set()
|
||||
|
||||
def add_body(self, body_id: int) -> None:
|
||||
"""Register a rigid body (vertex) with K=6 free pebbles."""
|
||||
if body_id in self._bodies:
|
||||
return
|
||||
self._bodies.add(body_id)
|
||||
self.state.free_pebbles[body_id] = self.K
|
||||
self.state.incoming[body_id] = set()
|
||||
self.state.outgoing[body_id] = set()
|
||||
|
||||
def add_joint(self, joint: Joint) -> list[dict[str, Any]]:
|
||||
"""Expand a joint into multigraph edges and test each for independence.
|
||||
|
||||
A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph.
|
||||
Each edge is tested individually via the pebble game.
|
||||
|
||||
Returns a list of dicts, one per scalar constraint, with:
|
||||
- edge_id: int
|
||||
- independent: bool
|
||||
- dof_remaining: int (total free pebbles after this edge)
|
||||
"""
|
||||
self.add_body(joint.body_a)
|
||||
self.add_body(joint.body_b)
|
||||
|
||||
num_constraints = joint.joint_type.dof
|
||||
results: list[dict[str, Any]] = []
|
||||
|
||||
for i in range(num_constraints):
|
||||
edge_id = self._edge_counter
|
||||
self._edge_counter += 1
|
||||
|
||||
independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b)
|
||||
total_free = sum(self.state.free_pebbles.values())
|
||||
|
||||
results.append(
|
||||
{
|
||||
"edge_id": edge_id,
|
||||
"joint_id": joint.joint_id,
|
||||
"constraint_index": i,
|
||||
"independent": independent,
|
||||
"dof_remaining": total_free,
|
||||
}
|
||||
)
|
||||
|
||||
return results
|
||||
|
||||
def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool:
|
||||
"""Try to insert a directed edge between u and v.
|
||||
|
||||
The edge is accepted (independent) iff we can collect L+1 = 7
|
||||
pebbles on the two endpoints {u, v} combined.
|
||||
|
||||
If accepted, one pebble is consumed and the edge is directed
|
||||
away from the vertex that gives up the pebble.
|
||||
"""
|
||||
# Count current free pebbles on u and v
|
||||
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
|
||||
|
||||
# Try to gather enough pebbles via DFS reachability search
|
||||
if available < self.L + 1:
|
||||
needed = (self.L + 1) - available
|
||||
# Try to free pebbles by searching from u first, then v
|
||||
for target in (u, v):
|
||||
while needed > 0:
|
||||
found = self._search_and_collect(target, frozenset({u, v}))
|
||||
if not found:
|
||||
break
|
||||
needed -= 1
|
||||
|
||||
# Recheck after collection attempts
|
||||
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
|
||||
|
||||
if available >= self.L + 1:
|
||||
# Accept: consume a pebble from whichever endpoint has one
|
||||
source = u if self.state.free_pebbles[u] > 0 else v
|
||||
|
||||
self.state.free_pebbles[source] -= 1
|
||||
self.state.directed_edges[edge_id] = (source, v if source == u else u)
|
||||
self.state.outgoing[source].add((edge_id, v if source == u else u))
|
||||
target = v if source == u else u
|
||||
self.state.incoming[target].add((edge_id, source))
|
||||
self.state.independent_edges.add(edge_id)
|
||||
return True
|
||||
else:
|
||||
# Reject: edge is redundant (overconstrained)
|
||||
self.state.redundant_edges.add(edge_id)
|
||||
return False
|
||||
|
||||
def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
|
||||
"""DFS to find a free pebble reachable from *target* and move it.
|
||||
|
||||
Follows directed edges *backwards* (from destination to source)
|
||||
to find a vertex with a free pebble that isn't in *forbidden*.
|
||||
When found, reverses the path to move the pebble to *target*.
|
||||
|
||||
Returns True if a pebble was successfully moved to target.
|
||||
"""
|
||||
# BFS/DFS through the directed graph following outgoing edges
|
||||
# from target. An outgoing edge (target -> w) means target spent
|
||||
# a pebble on that edge. If we can find a vertex with a free
|
||||
# pebble, we reverse edges along the path to move it.
|
||||
|
||||
visited: set[int] = set()
|
||||
# Stack: (current_vertex, path_of_edge_ids_to_reverse)
|
||||
stack: list[tuple[int, list[int]]] = [(target, [])]
|
||||
|
||||
while stack:
|
||||
current, path = stack.pop()
|
||||
if current in visited:
|
||||
continue
|
||||
visited.add(current)
|
||||
|
||||
# Check if current vertex (not in forbidden, not target)
|
||||
# has a free pebble
|
||||
if (
|
||||
current != target
|
||||
and current not in forbidden
|
||||
and self.state.free_pebbles[current] > 0
|
||||
):
|
||||
# Found a pebble — reverse the path
|
||||
self._reverse_path(path, current)
|
||||
return True
|
||||
|
||||
# Follow outgoing edges from current vertex
|
||||
for eid, neighbor in self.state.outgoing.get(current, set()):
|
||||
if neighbor not in visited:
|
||||
stack.append((neighbor, [*path, eid]))
|
||||
|
||||
return False
|
||||
|
||||
def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
|
||||
"""Reverse directed edges along a path, moving a pebble to the start.
|
||||
|
||||
The pebble at *pebble_source* is consumed by the last edge in
|
||||
the path, and a pebble is freed at the path's start vertex.
|
||||
"""
|
||||
if not edge_ids:
|
||||
return
|
||||
|
||||
# Reverse each edge in the path
|
||||
for eid in edge_ids:
|
||||
old_source, old_target = self.state.directed_edges[eid]
|
||||
|
||||
# Remove from adjacency
|
||||
self.state.outgoing[old_source].discard((eid, old_target))
|
||||
self.state.incoming[old_target].discard((eid, old_source))
|
||||
|
||||
# Reverse direction
|
||||
self.state.directed_edges[eid] = (old_target, old_source)
|
||||
self.state.outgoing[old_target].add((eid, old_source))
|
||||
self.state.incoming[old_source].add((eid, old_target))
|
||||
|
||||
# Move pebble counts: source loses one, first vertex in path gains one
|
||||
self.state.free_pebbles[pebble_source] -= 1
|
||||
|
||||
# After all reversals, the vertex at the beginning of the
|
||||
# search path gains a pebble
|
||||
_first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
|
||||
self.state.free_pebbles[first_tgt] += 1
|
||||
|
||||
def get_dof(self) -> int:
|
||||
"""Total remaining DOF = sum of free pebbles.
|
||||
|
||||
For a fully rigid assembly, this should be 6 (the trivial rigid
|
||||
body motions of the whole assembly). Internal DOF = total - 6.
|
||||
"""
|
||||
return sum(self.state.free_pebbles.values())
|
||||
|
||||
def get_internal_dof(self) -> int:
|
||||
"""Internal (non-trivial) degrees of freedom."""
|
||||
return max(0, self.get_dof() - 6)
|
||||
|
||||
def is_rigid(self) -> bool:
|
||||
"""Combinatorial rigidity check: rigid iff at most 6 pebbles remain."""
|
||||
return self.get_dof() <= self.L
|
||||
|
||||
def get_redundant_count(self) -> int:
|
||||
"""Number of redundant (overconstrained) scalar constraints."""
|
||||
return len(self.state.redundant_edges)
|
||||
|
||||
def classify_assembly(self, *, grounded: bool = False) -> str:
    """Classify the assembly state.

    Returns one of "mixed", "overconstrained", "underconstrained",
    or "well-constrained".

    Args:
        grounded: If True, the baseline trivial DOF is 0 (not 6),
            because the ground body's 6 DOF were removed.
    """
    trivial_dof = 0 if grounded else self.L
    total = self.get_dof()
    has_redundancy = self.get_redundant_count() > 0
    has_excess_dof = total > trivial_dof

    if has_redundancy and has_excess_dof:
        # Both under- and over-constrained regions coexist.
        return "mixed"
    if has_redundancy:
        return "overconstrained"
    if has_excess_dof:
        return "underconstrained"
    if total == trivial_dof:
        return "well-constrained"
    # Fewer DOF than the trivial baseline without tracked redundancy
    # still indicates an over-constrained system.
    return "overconstrained"
|
||||
144
solver/datagen/types.py
Normal file
144
solver/datagen/types.py
Normal file
@@ -0,0 +1,144 @@
|
||||
"""Shared data types for assembly constraint analysis.
|
||||
|
||||
Types ported from the pebble-game synthetic data generator for reuse
|
||||
across the solver package (data generation, training, inference).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"ConstraintAnalysis",
|
||||
"Joint",
|
||||
"JointType",
|
||||
"PebbleState",
|
||||
"RigidBody",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Joint definitions: each joint type removes a known number of DOF
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JointType(enum.Enum):
    """Standard CAD joint types with their DOF-removal counts.

    Each joint between two 6-DOF rigid bodies removes a specific number
    of relative degrees of freedom. In the body-bar-hinge multigraph
    representation, each joint maps to a number of edges equal to the
    DOF it removes.

    Values are ``(ordinal, dof_removed)`` tuples so that joint types
    sharing the same DOF count remain distinct enum members. Use the
    :attr:`dof` property to get the scalar constraint count.
    """

    # Member value = (ordinal, DOF removed between the two bodies).
    FIXED = (0, 6)  # Locks all relative motion
    REVOLUTE = (1, 5)  # Allows rotation about one axis
    CYLINDRICAL = (2, 4)  # Allows rotation + translation along one axis
    SLIDER = (3, 5)  # Allows translation along one axis (prismatic)
    BALL = (4, 3)  # Allows rotation about a point (spherical)
    PLANAR = (5, 3)  # Allows 2D translation + rotation normal to plane
    SCREW = (6, 5)  # Coupled rotation-translation (helical)
    UNIVERSAL = (7, 4)  # Two rotational DOF (Cardan/U-joint)
    PARALLEL = (8, 3)  # Forces parallel orientation (3 rotation constraints)
    PERPENDICULAR = (9, 1)  # Single angular constraint
    DISTANCE = (10, 1)  # Single scalar distance constraint

    @property
    def dof(self) -> int:
        """Number of scalar constraints (DOF removed) by this joint type."""
        # Second element of the (ordinal, dof_removed) value tuple.
        return self.value[1]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Data structures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class RigidBody:
    """A rigid body in the assembly with pose and geometry info."""

    body_id: int
    # World-frame translation of the body origin (3-vector).
    position: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # World-frame rotation (3x3 matrix); defaults to identity.
    orientation: np.ndarray = field(default_factory=lambda: np.eye(3))

    # Anchor points for joints, in local frame
    # Populated when joints reference specific geometry
    local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class Joint:
    """A joint connecting two rigid bodies."""

    joint_id: int
    body_a: int  # Index of first body
    body_b: int  # Index of second body
    joint_type: JointType

    # Joint parameters in world frame
    anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
    anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
    # Joint axis in world frame; defaults to +Z.
    axis: np.ndarray = field(
        default_factory=lambda: np.array([0.0, 0.0, 1.0]),
    )

    # For screw joints
    # NOTE(review): presumably translation per radian of rotation —
    # confirm units against the Jacobian builder.
    pitch: float = 0.0
|
||||
|
||||
|
||||
@dataclass
class PebbleState:
    """Tracks the state of the pebble game on the multigraph.

    Vertices are rigid bodies; each multigraph edge is one scalar
    constraint. All containers below must be kept mutually consistent
    by the game implementation.
    """

    # Number of free pebbles per body (vertex). Starts at 6.
    free_pebbles: dict[int, int] = field(default_factory=dict)

    # Directed edges: edge_id -> (source_body, target_body)
    # Edge is directed away from the body that "spent" a pebble.
    directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)

    # Track which edges are independent vs redundant
    independent_edges: set[int] = field(default_factory=set)
    redundant_edges: set[int] = field(default_factory=set)

    # Adjacency: body_id -> set of (edge_id, neighbor_body_id)
    # Following directed edges *towards* a body (incoming edges)
    incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)

    # Outgoing edges from a body
    outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
class ConstraintAnalysis:
    """Results of analyzing an assembly's constraint system.

    Combines the combinatorial (pebble game) view with the numerical
    (Jacobian rank) view; disagreements between the two indicate
    geometric degeneracies.
    """

    # Pebble game (combinatorial) results
    combinatorial_dof: int
    combinatorial_internal_dof: int
    combinatorial_redundant: int
    combinatorial_classification: str
    # Per-edge outcome records; schema defined by the analysis producer.
    per_edge_results: list[dict[str, Any]]

    # Numerical (Jacobian) results
    jacobian_rank: int
    jacobian_nullity: int  # = 6n - rank = total DOF
    jacobian_internal_dof: int  # = nullity - 6
    # NOTE(review): presumably indices of numerically dependent
    # constraint rows — confirm against the Jacobian analysis code.
    numerically_dependent: list[int]

    # Combined
    geometric_degeneracies: int  # = combinatorial_independent - jacobian_rank
    is_rigid: bool
    is_minimally_rigid: bool
|
||||
0
solver/datasets/__init__.py
Normal file
0
solver/datasets/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/evaluation/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
0
solver/inference/__init__.py
Normal file
47
solver/mates/__init__.py
Normal file
47
solver/mates/__init__.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Mate-level constraint types for assembly analysis."""
|
||||
|
||||
from solver.mates.conversion import (
|
||||
MateAnalysisResult,
|
||||
analyze_mate_assembly,
|
||||
convert_mates_to_joints,
|
||||
)
|
||||
from solver.mates.generator import (
|
||||
SyntheticMateGenerator,
|
||||
generate_mate_training_batch,
|
||||
)
|
||||
from solver.mates.labeling import (
|
||||
MateAssemblyLabels,
|
||||
MateLabel,
|
||||
label_mate_assembly,
|
||||
)
|
||||
from solver.mates.patterns import (
|
||||
JointPattern,
|
||||
PatternMatch,
|
||||
recognize_patterns,
|
||||
)
|
||||
from solver.mates.primitives import (
|
||||
GeometryRef,
|
||||
GeometryType,
|
||||
Mate,
|
||||
MateType,
|
||||
dof_removed,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"GeometryRef",
|
||||
"GeometryType",
|
||||
"JointPattern",
|
||||
"Mate",
|
||||
"MateAnalysisResult",
|
||||
"MateAssemblyLabels",
|
||||
"MateLabel",
|
||||
"MateType",
|
||||
"PatternMatch",
|
||||
"SyntheticMateGenerator",
|
||||
"analyze_mate_assembly",
|
||||
"convert_mates_to_joints",
|
||||
"dof_removed",
|
||||
"generate_mate_training_batch",
|
||||
"label_mate_assembly",
|
||||
"recognize_patterns",
|
||||
]
|
||||
276
solver/mates/conversion.py
Normal file
276
solver/mates/conversion.py
Normal file
@@ -0,0 +1,276 @@
|
||||
"""Mate-to-joint conversion and assembly analysis.
|
||||
|
||||
Bridges the mate-level constraint representation to the existing
|
||||
joint-based analysis pipeline. Converts recognized mate patterns
|
||||
to Joint objects, then runs the pebble game and Jacobian analysis,
|
||||
maintaining bidirectional traceability between mates and joints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.labeling import AssemblyLabels, label_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
from solver.mates.patterns import PatternMatch, recognize_patterns
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from solver.mates.primitives import Mate
|
||||
|
||||
__all__ = [
|
||||
"MateAnalysisResult",
|
||||
"analyze_mate_assembly",
|
||||
"convert_mates_to_joints",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Result dataclass
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class MateAnalysisResult:
    """Combined result of mate-based assembly analysis.

    Attributes:
        patterns: Recognized joint patterns from mate grouping.
        joints: Joint objects produced by conversion.
        mate_to_joint: Mapping from mate_id to list of joint_ids.
        joint_to_mates: Mapping from joint_id to list of mate_ids.
        analysis: Constraint analysis from pebble game + Jacobian.
        labels: Full ground truth labels from label_assembly.
    """

    patterns: list[PatternMatch]
    joints: list[Joint]
    mate_to_joint: dict[int, list[int]] = field(default_factory=dict)
    joint_to_mates: dict[int, list[int]] = field(default_factory=dict)
    analysis: ConstraintAnalysis | None = None
    labels: AssemblyLabels | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        joint_records = []
        for joint in self.joints:
            joint_records.append(
                {
                    "joint_id": joint.joint_id,
                    "body_a": joint.body_a,
                    "body_b": joint.body_b,
                    "joint_type": joint.joint_type.name,
                }
            )

        label_payload = self.labels.to_dict() if self.labels else None
        return {
            "patterns": [match.to_dict() for match in self.patterns],
            "joints": joint_records,
            "mate_to_joint": self.mate_to_joint,
            "joint_to_mates": self.joint_to_mates,
            "labels": label_payload,
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern-to-JointType mapping
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Maps (JointPattern value) to JointType for known patterns.
# Used by convert_mates_to_joints when a full pattern is recognized.
_PATTERN_JOINT_MAP: dict[str, JointType] = {
    "hinge": JointType.REVOLUTE,
    "slider": JointType.SLIDER,
    "cylinder": JointType.CYLINDRICAL,
    "ball": JointType.BALL,
    "planar": JointType.PLANAR,
    "fixed": JointType.FIXED,
}

# Fallback mapping for individual mate types when no pattern is recognized.
# Keys are MateType member names; each lone mate becomes one joint of the
# mapped type (an approximation of the mate's DOF removal).
_MATE_JOINT_FALLBACK: dict[str, JointType] = {
    "COINCIDENT": JointType.PLANAR,
    "CONCENTRIC": JointType.CYLINDRICAL,
    "PARALLEL": JointType.PARALLEL,
    "PERPENDICULAR": JointType.PERPENDICULAR,
    "TANGENT": JointType.DISTANCE,
    "DISTANCE": JointType.DISTANCE,
    "ANGLE": JointType.PERPENDICULAR,
    "LOCK": JointType.FIXED,
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Conversion
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _compute_joint_params(
|
||||
pattern: PatternMatch,
|
||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
||||
"""Extract anchor and axis from pattern mates.
|
||||
|
||||
Returns:
|
||||
(anchor_a, anchor_b, axis)
|
||||
"""
|
||||
anchor_a = np.zeros(3)
|
||||
anchor_b = np.zeros(3)
|
||||
axis = np.array([0.0, 0.0, 1.0])
|
||||
|
||||
for mate in pattern.mates:
|
||||
ref_a = mate.ref_a
|
||||
ref_b = mate.ref_b
|
||||
anchor_a = ref_a.origin.copy()
|
||||
anchor_b = ref_b.origin.copy()
|
||||
if ref_a.direction is not None:
|
||||
axis = ref_a.direction.copy()
|
||||
break
|
||||
|
||||
return anchor_a, anchor_b, axis
|
||||
|
||||
|
||||
def _convert_single_mate(
    mate: Mate,
    joint_id: int,
) -> Joint:
    """Convert a single unmatched mate to a Joint.

    The joint type is looked up by the mate type's name in
    ``_MATE_JOINT_FALLBACK`` (defaulting to a DISTANCE joint), and
    anchors/axis are copied from the mate's geometry references.
    """
    fallback_type = _MATE_JOINT_FALLBACK.get(mate.mate_type.name, JointType.DISTANCE)

    if mate.ref_a.direction is None:
        axis = np.array([0.0, 0.0, 1.0])
    else:
        axis = mate.ref_a.direction.copy()

    return Joint(
        joint_id=joint_id,
        body_a=mate.ref_a.body_id,
        body_b=mate.ref_b.body_id,
        joint_type=fallback_type,
        anchor_a=mate.ref_a.origin.copy(),
        anchor_b=mate.ref_b.origin.copy(),
        axis=axis,
    )
|
||||
|
||||
|
||||
def convert_mates_to_joints(
    mates: list[Mate],
    bodies: list[RigidBody] | None = None,
) -> tuple[list[Joint], dict[int, list[int]], dict[int, list[int]]]:
    """Convert mates to Joint objects via pattern recognition.

    For each body pair:
      - If mates form a recognized pattern, emit the equivalent joint.
      - Otherwise, emit individual joints for each unmatched mate.

    Args:
        mates: Mate constraints to convert.
        bodies: Optional body list (unused currently, reserved for
            future geometry lookups).

    Returns:
        (joints, mate_to_joint, joint_to_mates) tuple.
    """
    if not mates:
        return [], {}, {}

    joints: list[Joint] = []
    mate_to_joint: dict[int, list[int]] = {}
    joint_to_mates: dict[int, list[int]] = {}
    # Mates already claimed by a full-confidence pattern.
    consumed: set[int] = set()

    # First pass: one joint per full-confidence pattern with a known
    # joint equivalent.
    for pattern in recognize_patterns(mates):
        if pattern.confidence < 1.0:
            continue
        if pattern.pattern.value not in _PATTERN_JOINT_MAP:
            continue

        member_ids = [m.mate_id for m in pattern.mates]
        # Earlier patterns win: skip any pattern overlapping a claimed mate.
        if consumed.intersection(member_ids):
            continue

        jid = len(joints)  # joint ids are assigned sequentially
        anchor_a, anchor_b, axis = _compute_joint_params(pattern)
        joints.append(
            Joint(
                joint_id=jid,
                body_a=pattern.body_a,
                body_b=pattern.body_b,
                joint_type=_PATTERN_JOINT_MAP[pattern.pattern.value],
                anchor_a=anchor_a,
                anchor_b=anchor_b,
                axis=axis,
            )
        )

        joint_to_mates[jid] = member_ids
        for mid in member_ids:
            mate_to_joint.setdefault(mid, []).append(jid)
            consumed.add(mid)

    # Second pass: one fallback joint per mate not claimed by a pattern.
    for mate in mates:
        if mate.mate_id in consumed:
            continue

        jid = len(joints)
        joints.append(_convert_single_mate(mate, jid))
        joint_to_mates[jid] = [mate.mate_id]
        mate_to_joint.setdefault(mate.mate_id, []).append(jid)

    return joints, mate_to_joint, joint_to_mates
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Full analysis pipeline
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def analyze_mate_assembly(
    bodies: list[RigidBody],
    mates: list[Mate],
    ground_body: int | None = None,
) -> MateAnalysisResult:
    """Run the full analysis pipeline on a mate-based assembly.

    Orchestrates: recognize_patterns -> convert_mates_to_joints ->
    label_assembly, returning a combined result with full traceability.

    Args:
        bodies: Rigid bodies in the assembly.
        mates: Mate constraints between the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        MateAnalysisResult with patterns, joints, mappings, and labels.
    """
    recognized = recognize_patterns(mates)
    joints, mate_map, joint_map = convert_mates_to_joints(mates, bodies)

    # Ground-truth labels (pebble game + Jacobian) over the converted joints.
    ground_truth = label_assembly(bodies, joints, ground_body)

    return MateAnalysisResult(
        patterns=recognized,
        joints=joints,
        mate_to_joint=mate_map,
        joint_to_mates=joint_map,
        analysis=ground_truth.analysis,
        labels=ground_truth,
    )
|
||||
315
solver/mates/generator.py
Normal file
315
solver/mates/generator.py
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Mate-based synthetic assembly generator.
|
||||
|
||||
Wraps SyntheticAssemblyGenerator to produce mate-level training data.
|
||||
Generates joint-based assemblies via the existing generator, then
|
||||
reverse-maps joints to plausible mate combinations. Supports noise
|
||||
injection (redundant, missing, incompatible mates) for robust training.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
from solver.mates.conversion import MateAnalysisResult, analyze_mate_assembly
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"SyntheticMateGenerator",
|
||||
"generate_mate_training_batch",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Reverse mapping: JointType -> list of (MateType, geom_a, geom_b) combos
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _MateSpec:
    """Specification for a mate to generate from a joint."""

    mate_type: MateType
    # Geometry type referenced on the first body's side of the mate.
    geom_a: GeometryType
    # Geometry type referenced on the second body's side of the mate.
    geom_b: GeometryType
|
||||
|
||||
|
||||
# Reverse mapping: for each joint type, a plausible set of CAD mates
# that together realize it. Joint types missing here fall back to a
# single DISTANCE mate in _reverse_map_joint.
_JOINT_TO_MATES: dict[JointType, list[_MateSpec]] = {
    JointType.REVOLUTE: [
        _MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
        _MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
    ],
    JointType.CYLINDRICAL: [
        _MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
    ],
    JointType.BALL: [
        _MateSpec(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),
    ],
    JointType.FIXED: [
        _MateSpec(MateType.LOCK, GeometryType.FACE, GeometryType.FACE),
    ],
    JointType.SLIDER: [
        _MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
        _MateSpec(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
    ],
    JointType.PLANAR: [
        _MateSpec(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),
    ],
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Generator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class SyntheticMateGenerator:
    """Generates mate-based assemblies for training data.

    Wraps SyntheticAssemblyGenerator to produce joint-based assemblies,
    then reverse-maps each joint to a plausible set of mate constraints.

    Args:
        seed: Random seed for reproducibility.
        redundant_prob: Probability of injecting a redundant mate per joint.
        missing_prob: Probability of dropping a mate from a multi-mate pattern.
        incompatible_prob: Probability of injecting a mate with wrong geometry.
    """

    def __init__(
        self,
        seed: int = 42,
        *,
        redundant_prob: float = 0.0,
        missing_prob: float = 0.0,
        incompatible_prob: float = 0.0,
    ) -> None:
        # Joint-level generator and noise RNG share the seed, so one
        # seed fully determines the generated assembly.
        self._joint_gen = SyntheticAssemblyGenerator(seed=seed)
        self._rng = np.random.default_rng(seed)
        self.redundant_prob = redundant_prob
        self.missing_prob = missing_prob
        self.incompatible_prob = incompatible_prob

    def _make_geometry_ref(
        self,
        body_id: int,
        geom_type: GeometryType,
        joint: Joint,
        *,
        is_ref_a: bool = True,
    ) -> GeometryRef:
        """Create a GeometryRef from joint geometry.

        Uses joint anchor, axis, and body_id to produce a ref
        with realistic geometry for the given type.
        """
        origin = joint.anchor_a if is_ref_a else joint.anchor_b

        # Only direction-bearing geometry types carry the joint axis.
        direction: np.ndarray | None = None
        if geom_type in {GeometryType.AXIS, GeometryType.PLANE, GeometryType.FACE}:
            direction = joint.axis.copy()

        # Synthetic CAD-style identifier, e.g. "Axis001".
        geom_id = f"{geom_type.value.capitalize()}001"

        return GeometryRef(
            body_id=body_id,
            geometry_type=geom_type,
            geometry_id=geom_id,
            origin=origin.copy(),
            direction=direction,
        )

    def _reverse_map_joint(
        self,
        joint: Joint,
        next_mate_id: int,
    ) -> list[Mate]:
        """Convert a joint to its mate representation.

        Args:
            joint: Joint to reverse-map.
            next_mate_id: First mate id to assign; subsequent mates
                get consecutive ids.
        """
        specs = _JOINT_TO_MATES.get(joint.joint_type, [])
        if not specs:
            # Fallback: emit a single DISTANCE mate
            specs = [_MateSpec(MateType.DISTANCE, GeometryType.POINT, GeometryType.POINT)]

        mates: list[Mate] = []
        for spec in specs:
            ref_a = self._make_geometry_ref(joint.body_a, spec.geom_a, joint, is_ref_a=True)
            ref_b = self._make_geometry_ref(joint.body_b, spec.geom_b, joint, is_ref_a=False)
            mates.append(
                Mate(
                    mate_id=next_mate_id + len(mates),
                    mate_type=spec.mate_type,
                    ref_a=ref_a,
                    ref_b=ref_b,
                )
            )
        return mates

    def _inject_noise(
        self,
        mates: list[Mate],
        next_mate_id: int,
    ) -> list[Mate]:
        """Apply noise injection to the mate list.

        The input list is NOT mutated: a shallow copy is extended
        with redundant/incompatible mates and then optionally
        filtered. Returns the new list.
        """
        result = list(mates)
        extra: list[Mate] = []

        for mate in mates:
            # Redundant: duplicate a mate
            if self._rng.random() < self.redundant_prob:
                # NOTE: the duplicate shares ref objects with the
                # original (no deep copy).
                dup = Mate(
                    mate_id=next_mate_id + len(extra),
                    mate_type=mate.mate_type,
                    ref_a=mate.ref_a,
                    ref_b=mate.ref_b,
                    value=mate.value,
                    tolerance=mate.tolerance,
                )
                extra.append(dup)

            # Incompatible: wrong geometry type
            if self._rng.random() < self.incompatible_prob:
                bad_geom = GeometryType.POINT
                bad_ref = GeometryRef(
                    body_id=mate.ref_a.body_id,
                    geometry_type=bad_geom,
                    geometry_id="BadGeom001",
                    origin=mate.ref_a.origin.copy(),
                    direction=None,
                )
                # CONCENTRIC on a POINT ref is intentionally invalid.
                extra.append(
                    Mate(
                        mate_id=next_mate_id + len(extra),
                        mate_type=MateType.CONCENTRIC,
                        ref_a=bad_ref,
                        ref_b=mate.ref_b,
                    )
                )

        result.extend(extra)

        # Missing: drop mates from multi-mate patterns (only if > 1 mate
        # for same body pair)
        if self.missing_prob > 0:
            filtered: list[Mate] = []
            for mate in result:
                if self._rng.random() < self.missing_prob:
                    continue
                filtered.append(mate)
            # Ensure at least one mate remains
            if not filtered and result:
                filtered = [result[0]]
            result = filtered

        return result

    def generate(
        self,
        n_bodies: int = 4,
        *,
        grounded: bool = False,
    ) -> tuple[list[RigidBody], list[Mate], MateAnalysisResult]:
        """Generate a mate-based assembly.

        Args:
            n_bodies: Number of rigid bodies.
            grounded: Whether to ground the first body.

        Returns:
            (bodies, mates, analysis_result) tuple.
        """
        bodies, joints, _analysis = self._joint_gen.generate_chain_assembly(
            n_bodies,
            joint_type=JointType.REVOLUTE,
            grounded=grounded,
        )

        # Reverse-map each joint to mates, assigning sequential ids.
        mates: list[Mate] = []
        next_id = 0
        for joint in joints:
            joint_mates = self._reverse_map_joint(joint, next_id)
            mates.extend(joint_mates)
            next_id += len(joint_mates)

        # Apply noise
        mates = self._inject_noise(mates, next_id)

        ground_body = bodies[0].body_id if grounded else None
        result = analyze_mate_assembly(bodies, mates, ground_body)

        return bodies, mates, result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch generation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def generate_mate_training_batch(
    batch_size: int = 100,
    n_bodies_range: tuple[int, int] = (3, 8),
    seed: int = 42,
    *,
    redundant_prob: float = 0.0,
    missing_prob: float = 0.0,
    incompatible_prob: float = 0.0,
    grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
    """Produce a batch of mate-level training examples.

    Args:
        batch_size: Number of assemblies to generate.
        n_bodies_range: (min, max_exclusive) body count.
        seed: Random seed.
        redundant_prob: Probability of redundant mate injection.
        missing_prob: Probability of missing mate injection.
        incompatible_prob: Probability of incompatible mate injection.
        grounded_ratio: Fraction of assemblies that are grounded.

    Returns:
        List of dicts with bodies, mates, patterns, and labels.
    """
    rng = np.random.default_rng(seed)

    def _one_example(index: int) -> dict[str, Any]:
        # Each assembly gets its own generator seeded deterministically
        # from the batch seed, so examples are independently reproducible.
        generator = SyntheticMateGenerator(
            seed=seed + index,
            redundant_prob=redundant_prob,
            missing_prob=missing_prob,
            incompatible_prob=incompatible_prob,
        )
        body_count = int(rng.integers(*n_bodies_range))
        is_grounded = bool(rng.random() < grounded_ratio)

        bodies, mates, result = generator.generate(body_count, grounded=is_grounded)

        return {
            "bodies": [
                {
                    "body_id": b.body_id,
                    "position": b.position.tolist(),
                }
                for b in bodies
            ],
            "mates": [m.to_dict() for m in mates],
            "patterns": [p.to_dict() for p in result.patterns],
            "labels": result.labels.to_dict() if result.labels else None,
            "n_bodies": len(bodies),
            "n_mates": len(mates),
            "n_joints": len(result.joints),
        }

    return [_one_example(i) for i in range(batch_size)]
|
||||
224
solver/mates/labeling.py
Normal file
224
solver/mates/labeling.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Mate-level ground truth labels for assembly analysis.
|
||||
|
||||
Back-attributes joint-level independence results to originating mates
|
||||
via the mate-to-joint mapping from conversion.py. Produces per-mate
|
||||
labels indicating whether each mate is independent, redundant, or
|
||||
degenerate.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.mates.conversion import analyze_mate_assembly
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
from solver.datagen.labeling import AssemblyLabel
|
||||
from solver.datagen.types import ConstraintAnalysis, RigidBody
|
||||
from solver.mates.conversion import MateAnalysisResult
|
||||
from solver.mates.patterns import JointPattern, PatternMatch
|
||||
from solver.mates.primitives import Mate
|
||||
|
||||
__all__ = [
|
||||
"MateAssemblyLabels",
|
||||
"MateLabel",
|
||||
"label_mate_assembly",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Label dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class MateLabel:
    """Per-mate ground truth label.

    Attributes:
        mate_id: The mate this label refers to.
        is_independent: Contributes non-redundant DOF removal.
        is_redundant: Fully redundant (removable without DOF change).
        is_degenerate: Combinatorially independent but geometrically dependent.
        pattern: Which joint pattern this mate belongs to, if any.
        issue: Detected issue type, if any.
    """

    mate_id: int
    is_independent: bool = True
    is_redundant: bool = False
    is_degenerate: bool = False
    pattern: JointPattern | None = None
    issue: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        pattern_value = self.pattern.value if self.pattern else None
        return {
            "mate_id": self.mate_id,
            "is_independent": self.is_independent,
            "is_redundant": self.is_redundant,
            "is_degenerate": self.is_degenerate,
            "pattern": pattern_value,
            "issue": self.issue,
        }
|
||||
|
||||
|
||||
@dataclass
class MateAssemblyLabels:
    """Complete mate-level ground truth labels for an assembly.

    Attributes:
        per_mate: Per-mate labels.
        patterns: Recognized joint patterns.
        assembly: Assembly-wide summary label.
        analysis: Constraint analysis from pebble game + Jacobian.
        mate_analysis: Optional full mate-analysis result.
    """

    per_mate: list[MateLabel]
    patterns: list[PatternMatch]
    assembly: AssemblyLabel
    analysis: ConstraintAnalysis
    mate_analysis: MateAnalysisResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        summary = self.assembly
        return {
            "per_mate": [label.to_dict() for label in self.per_mate],
            "patterns": [match.to_dict() for match in self.patterns],
            "assembly": {
                "classification": summary.classification,
                "total_dof": summary.total_dof,
                "redundant_count": summary.redundant_count,
                "is_rigid": summary.is_rigid,
                "is_minimally_rigid": summary.is_minimally_rigid,
                "has_degeneracy": summary.has_degeneracy,
            },
        }
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Labeling logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _build_mate_pattern_map(
|
||||
patterns: list[PatternMatch],
|
||||
) -> dict[int, JointPattern]:
|
||||
"""Map mate_ids to the pattern they belong to (best match)."""
|
||||
result: dict[int, JointPattern] = {}
|
||||
# Sort by confidence descending so best matches win
|
||||
sorted_patterns = sorted(patterns, key=lambda p: -p.confidence)
|
||||
for pm in sorted_patterns:
|
||||
if pm.confidence < 1.0:
|
||||
continue
|
||||
for mate in pm.mates:
|
||||
if mate.mate_id not in result:
|
||||
result[mate.mate_id] = pm.pattern
|
||||
return result
|
||||
|
||||
|
||||
def label_mate_assembly(
    bodies: list[RigidBody],
    mates: list[Mate],
    ground_body: int | None = None,
) -> MateAssemblyLabels:
    """Produce mate-level ground truth labels for an assembly.

    Runs analyze_mate_assembly() internally, then back-attributes
    joint-level independence to originating mates via the mate_to_joint
    mapping.

    A mate is:
    - **redundant** if ALL joints it contributes to are fully redundant
    - **degenerate** if any joint it contributes to is geometrically
      dependent but combinatorially independent
    - **independent** otherwise

    Args:
        bodies: Rigid bodies in the assembly.
        mates: Mate constraints between the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        MateAssemblyLabels with per-mate labels and assembly summary.
    """
    mate_result = analyze_mate_assembly(bodies, mates, ground_body)

    # Build per-joint redundancy from labels
    joint_redundant: dict[int, bool] = {}
    joint_degenerate: dict[int, bool] = {}

    if mate_result.labels is not None:
        for jl in mate_result.labels.per_joint:
            # A joint is fully redundant if all its constraints are redundant
            # (total > 0 guards against labelling an empty joint redundant).
            joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0
            # Joint is degenerate if it has more independent constraints
            # than Jacobian rank would suggest (geometric degeneracy)
            joint_degenerate[jl.joint_id] = False

        # Check for geometric degeneracy via per-constraint labels:
        # combinatorially independent (pebble) but numerically dependent (Jacobian).
        for cl in mate_result.labels.per_constraint:
            if cl.pebble_independent and not cl.jacobian_independent:
                joint_degenerate[cl.joint_id] = True

    # Build pattern membership map (only full-confidence patterns contribute).
    pattern_map = _build_mate_pattern_map(mate_result.patterns)

    # Back-attribute to mates
    per_mate: list[MateLabel] = []
    for mate in mates:
        mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, [])

        if not mate_joint_ids:
            # Mate wasn't converted to any joint (shouldn't happen, but safe)
            per_mate.append(
                MateLabel(
                    mate_id=mate.mate_id,
                    is_independent=False,
                    is_redundant=True,
                    issue="unmapped",
                )
            )
            continue

        # Redundant if ALL contributed joints are redundant
        all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids)

        # Degenerate if ANY contributed joint is degenerate
        any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids)

        is_independent = not all_redundant
        pattern = pattern_map.get(mate.mate_id)

        # Determine issue string; redundancy takes priority over degeneracy.
        issue: str | None = None
        if all_redundant:
            issue = "redundant"
        elif any_degenerate:
            issue = "degenerate"

        per_mate.append(
            MateLabel(
                mate_id=mate.mate_id,
                is_independent=is_independent,
                is_redundant=all_redundant,
                is_degenerate=any_degenerate,
                pattern=pattern,
                issue=issue,
            )
        )

    # Assembly label
    # NOTE(review): `assert` is stripped under `python -O`; presumably
    # analyze_mate_assembly always populates labels — confirm upstream.
    assert mate_result.labels is not None
    assembly_label = mate_result.labels.assembly

    return MateAssemblyLabels(
        per_mate=per_mate,
        patterns=mate_result.patterns,
        assembly=assembly_label,
        analysis=mate_result.labels.analysis,
        mate_analysis=mate_result,
    )
|
||||
284
solver/mates/patterns.py
Normal file
284
solver/mates/patterns.py
Normal file
@@ -0,0 +1,284 @@
|
||||
"""Joint pattern recognition from mate combinations.
|
||||
|
||||
Groups mates by body pair and matches them against canonical joint
|
||||
patterns (hinge, slider, ball, etc.). Each pattern is a known
|
||||
combination of mate types that together constrain motion equivalently
|
||||
to a single mechanical joint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.types import JointType
|
||||
from solver.mates.primitives import GeometryType, Mate, MateType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"JointPattern",
|
||||
"PatternMatch",
|
||||
"recognize_patterns",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class JointPattern(enum.Enum):
    """Canonical joint patterns formed by mate combinations.

    The string values are the serialized identifiers emitted by
    :meth:`PatternMatch.to_dict`, so renaming them changes the on-disk
    label format.
    """

    HINGE = "hinge"  # concentric axes + coincident plane (-> REVOLUTE)
    SLIDER = "slider"  # coincident plane + parallel axis (-> SLIDER)
    CYLINDER = "cylinder"  # concentric axes only (-> CYLINDRICAL)
    BALL = "ball"  # coincident points (-> BALL)
    PLANAR = "planar"  # coincident faces or planes (-> PLANAR)
    FIXED = "fixed"  # lock mate (-> FIXED)
    GEAR = "gear"  # no matching rule in _PATTERN_RULES yet
    RACK_PINION = "rack_pinion"  # no matching rule in _PATTERN_RULES yet
    UNKNOWN = "unknown"  # fallback when no rule matches a body pair
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern match result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class PatternMatch:
    """Outcome of matching one body-pair's mates against a joint pattern.

    Attributes:
        pattern: The identified joint pattern.
        mates: The mates that form this pattern.
        body_a: First body in the pair.
        body_b: Second body in the pair.
        confidence: Fraction of the canonical pattern satisfied (0-1).
        equivalent_joint_type: The JointType this pattern maps to.
        missing_mates: Descriptions of mates absent for a full match.
    """

    pattern: JointPattern
    mates: list[Mate]
    body_a: int
    body_b: int
    confidence: float
    equivalent_joint_type: JointType
    missing_mates: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        payload: dict[str, Any] = {
            "pattern": self.pattern.value,
            "body_a": self.body_a,
            "body_b": self.body_b,
            "confidence": self.confidence,
            "equivalent_joint_type": self.equivalent_joint_type.name,
        }
        # Mates themselves are referenced by id only, not embedded.
        payload["mate_ids"] = [mate.mate_id for mate in self.mates]
        payload["missing_mates"] = self.missing_mates
        return payload
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pattern rules (data-driven)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _MateRequirement:
    """A single mate requirement within a pattern rule.

    A ``None`` geometry on either side acts as a wildcard: any geometry
    type on that ref satisfies the requirement (see
    ``_mate_matches_requirement``).
    """

    mate_type: MateType
    geometry_a: GeometryType | None = None
    geometry_b: GeometryType | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class _PatternRule:
    """Defines a canonical pattern as a set of required mates.

    All ``required`` entries must be satisfied by distinct mates for a
    full-confidence (1.0) match; partial matches score the fraction of
    requirements met.
    """

    pattern: JointPattern
    joint_type: JointType
    required: tuple[_MateRequirement, ...]
    description: str = ""
|
||||
|
||||
|
||||
# Canonical pattern rule table. recognize_patterns() tries every rule
# against each body-pair group of mates. Note that PLANAR appears twice:
# both FACE-FACE and PLANE-PLANE coincidence are accepted as planar.
_PATTERN_RULES: list[_PatternRule] = [
    _PatternRule(
        pattern=JointPattern.HINGE,
        joint_type=JointType.REVOLUTE,
        required=(
            _MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
            _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
        ),
        description="Concentric axes + coincident plane",
    ),
    _PatternRule(
        pattern=JointPattern.SLIDER,
        joint_type=JointType.SLIDER,
        required=(
            _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
            _MateRequirement(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
        ),
        description="Coincident plane + parallel axis",
    ),
    _PatternRule(
        pattern=JointPattern.CYLINDER,
        joint_type=JointType.CYLINDRICAL,
        required=(_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),),
        description="Concentric axes only",
    ),
    _PatternRule(
        pattern=JointPattern.BALL,
        joint_type=JointType.BALL,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),),
        description="Coincident points",
    ),
    _PatternRule(
        pattern=JointPattern.PLANAR,
        joint_type=JointType.PLANAR,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),),
        description="Coincident faces",
    ),
    _PatternRule(
        pattern=JointPattern.PLANAR,
        joint_type=JointType.PLANAR,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),),
        description="Coincident planes (alternate planar)",
    ),
    _PatternRule(
        pattern=JointPattern.FIXED,
        joint_type=JointType.FIXED,
        # LOCK needs no geometry constraints — wildcard on both refs.
        required=(_MateRequirement(MateType.LOCK),),
        description="Lock mate",
    ),
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Matching logic
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _mate_matches_requirement(mate: Mate, req: _MateRequirement) -> bool:
|
||||
"""Check if a mate satisfies a requirement."""
|
||||
if mate.mate_type is not req.mate_type:
|
||||
return False
|
||||
if req.geometry_a is not None and mate.ref_a.geometry_type is not req.geometry_a:
|
||||
return False
|
||||
return not (req.geometry_b is not None and mate.ref_b.geometry_type is not req.geometry_b)
|
||||
|
||||
|
||||
def _try_match_rule(
    rule: _PatternRule,
    mates: list[Mate],
) -> tuple[float, list[Mate], list[str]]:
    """Try to match a rule against a group of mates.

    Requirements are consumed greedily in rule order; each mate can
    satisfy at most one requirement.

    Fix: consumed mates were previously tracked with ``mate in matched``,
    which falls back to dataclass equality — a distinct-but-equal mate
    was wrongly treated as already consumed, and equality on
    numpy-backed GeometryRef fields is fragile. Track consumption by
    object identity instead, matching the intended "this exact mate is
    used" semantics.

    Args:
        rule: The pattern rule to match.
        mates: Candidate mates from one body pair.

    Returns:
        (confidence, matched_mates, missing_descriptions)
    """
    matched: list[Mate] = []
    consumed: set[int] = set()  # id()s of mates already assigned to a requirement
    missing: list[str] = []

    for req in rule.required:
        found = False
        for mate in mates:
            if id(mate) in consumed:
                continue
            if _mate_matches_requirement(mate, req):
                matched.append(mate)
                consumed.add(id(mate))
                found = True
                break
        if not found:
            # Describe the unmet requirement, e.g. "CONCENTRIC (axis-axis)".
            geom_desc = ""
            if req.geometry_a is not None:
                geom_b = req.geometry_b.value if req.geometry_b else "*"
                geom_desc = f" ({req.geometry_a.value}-{geom_b})"
            missing.append(f"{req.mate_type.name}{geom_desc}")

    total_required = len(rule.required)
    if total_required == 0:
        # Degenerate rule with nothing to match: never a hit.
        return 0.0, [], []

    confidence = len(matched) / total_required
    return confidence, matched, missing
|
||||
|
||||
|
||||
def _normalize_body_pair(body_a: int, body_b: int) -> tuple[int, int]:
|
||||
"""Normalize a body pair so the smaller ID comes first."""
|
||||
return (min(body_a, body_b), max(body_a, body_b))
|
||||
|
||||
|
||||
def recognize_patterns(mates: list[Mate]) -> list[PatternMatch]:
    """Identify joint patterns from a list of mates.

    Mates are bucketed by (normalized) body pair, and each bucket is
    checked against every canonical pattern rule. Pairs with no rule hit
    at all get a zero-confidence UNKNOWN placeholder.

    Args:
        mates: List of mate constraints to analyze.

    Returns:
        List of PatternMatch results, highest confidence first.
    """
    if not mates:
        return []

    # Bucket mates by their normalized body pair.
    groups: dict[tuple[int, int], list[Mate]] = defaultdict(list)
    for mate in mates:
        key = _normalize_body_pair(mate.ref_a.body_id, mate.ref_b.body_id)
        groups[key].append(mate)

    results: list[PatternMatch] = []

    for (body_a, body_b), pair_mates in groups.items():
        candidates: list[PatternMatch] = []
        for rule in _PATTERN_RULES:
            confidence, matched, missing = _try_match_rule(rule, pair_mates)
            if confidence <= 0:
                continue
            candidates.append(
                PatternMatch(
                    pattern=rule.pattern,
                    mates=matched if matched else pair_mates,
                    body_a=body_a,
                    body_b=body_b,
                    confidence=confidence,
                    equivalent_joint_type=rule.joint_type,
                    missing_mates=missing,
                )
            )

        if not candidates:
            # Nothing matched: record an UNKNOWN placeholder for the pair.
            results.append(
                PatternMatch(
                    pattern=JointPattern.UNKNOWN,
                    mates=pair_mates,
                    body_a=body_a,
                    body_b=body_b,
                    confidence=0.0,
                    equivalent_joint_type=JointType.DISTANCE,
                    missing_mates=[],
                )
            )
            continue

        # Within the pair: highest confidence first, ties broken toward
        # the match covering more mates (more specific pattern).
        candidates.sort(key=lambda m: (-m.confidence, -len(m.mates)))
        results.extend(candidates)

    # Final global ordering by confidence descending (stable).
    results.sort(key=lambda m: -m.confidence)
    return results
|
||||
279
solver/mates/primitives.py
Normal file
279
solver/mates/primitives.py
Normal file
@@ -0,0 +1,279 @@
|
||||
"""Mate type definitions and geometry references for assembly constraints.
|
||||
|
||||
Mates are the user-facing constraint primitives in CAD (e.g. SolidWorks-style
|
||||
Coincident, Concentric, Parallel). Each mate references geometry on two bodies
|
||||
and removes a context-dependent number of degrees of freedom.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"GeometryRef",
|
||||
"GeometryType",
|
||||
"Mate",
|
||||
"MateType",
|
||||
"dof_removed",
|
||||
]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Enums
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class MateType(enum.Enum):
    """CAD mate types with default DOF-removal counts.

    Values are ``(ordinal, default_dof)`` tuples so that mate types
    sharing the same DOF count remain distinct enum members. Use the
    :attr:`default_dof` property to get the scalar constraint count.

    The actual DOF removed can be context-dependent (e.g. COINCIDENT
    removes 3 DOF for face-face but only 1 for face-point). Use
    :func:`dof_removed` for the context-aware count.
    """

    COINCIDENT = (0, 3)  # context-dependent; see _DOF_TABLE
    CONCENTRIC = (1, 2)
    PARALLEL = (2, 2)
    PERPENDICULAR = (3, 1)
    TANGENT = (4, 1)
    DISTANCE = (5, 1)
    ANGLE = (6, 1)
    LOCK = (7, 6)  # removes all six relative DOF

    @property
    def default_dof(self) -> int:
        """Default number of DOF removed by this mate type."""
        return self.value[1]
|
||||
|
||||
|
||||
class GeometryType(enum.Enum):
    """Types of geometric references used by mates.

    FACE, AXIS and PLANE are directional (see ``_DIRECTIONAL_TYPES``)
    and must carry a direction vector on their :class:`GeometryRef`.
    """

    FACE = "face"
    EDGE = "edge"
    POINT = "point"
    AXIS = "axis"
    PLANE = "plane"
|
||||
|
||||
|
||||
# Geometry types that require a direction vector.
# Used by Mate.validate(): refs of these types must have a non-None
# ``direction``; POINT and EDGE refs may omit it.
_DIRECTIONAL_TYPES = frozenset(
    {
        GeometryType.FACE,
        GeometryType.AXIS,
        GeometryType.PLANE,
    }
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataclasses
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class GeometryRef:
    """A reference to one geometric entity on a body.

    Attributes:
        body_id: Index of the body this geometry belongs to.
        geometry_type: What kind of geometry (face, edge, etc.).
        geometry_id: CAD identifier string (e.g. ``"Face001"``).
        origin: 3D position of the geometry reference point.
        direction: Unit direction vector. Required for FACE, AXIS, PLANE;
            ``None`` for POINT.
    """

    body_id: int
    geometry_type: GeometryType
    geometry_id: str
    origin: np.ndarray = field(default_factory=lambda: np.zeros(3))
    direction: np.ndarray | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        if self.direction is None:
            direction_list = None
        else:
            direction_list = self.direction.tolist()
        return {
            "body_id": self.body_id,
            "geometry_type": self.geometry_type.value,
            "geometry_id": self.geometry_id,
            "origin": self.origin.tolist(),
            "direction": direction_list,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GeometryRef:
        """Construct from a dict produced by :meth:`to_dict`."""
        raw_direction = data.get("direction")
        direction = None if raw_direction is None else np.asarray(raw_direction, dtype=np.float64)
        return cls(
            body_id=data["body_id"],
            geometry_type=GeometryType(data["geometry_type"]),
            geometry_id=data["geometry_id"],
            origin=np.asarray(data["origin"], dtype=np.float64),
            direction=direction,
        )
|
||||
|
||||
|
||||
@dataclass
class Mate:
    """A mate constraint between geometry on two bodies.

    Attributes:
        mate_id: Unique identifier for this mate.
        mate_type: The type of constraint (Coincident, Concentric, etc.).
        ref_a: Geometry reference on the first body.
        ref_b: Geometry reference on the second body.
        value: Scalar parameter for DISTANCE and ANGLE mates (0 otherwise).
        tolerance: Numeric tolerance for constraint satisfaction.
    """

    mate_id: int
    mate_type: MateType
    ref_a: GeometryRef
    ref_b: GeometryRef
    value: float = 0.0
    tolerance: float = 1e-6

    def validate(self) -> None:
        """Raise ``ValueError`` if this mate has incompatible geometry.

        Checks:
        - Self-mate (both refs on same body)
        - CONCENTRIC requires AXIS geometry on both refs
        - PARALLEL requires directional geometry (not POINT)
        - TANGENT requires surface geometry (FACE or EDGE)
        - Directional geometry types must have a direction vector
        """
        # A mate must connect two distinct bodies.
        if self.ref_a.body_id == self.ref_b.body_id:
            msg = f"Self-mate: ref_a and ref_b both reference body {self.ref_a.body_id}"
            raise ValueError(msg)

        # FACE/AXIS/PLANE refs must carry a direction vector.
        for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
            if ref.geometry_type in _DIRECTIONAL_TYPES and ref.direction is None:
                msg = (
                    f"{label}: geometry type {ref.geometry_type.value} requires a direction vector"
                )
                raise ValueError(msg)

        if self.mate_type is MateType.CONCENTRIC:
            for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
                if ref.geometry_type is not GeometryType.AXIS:
                    msg = (
                        f"CONCENTRIC mate requires AXIS geometry, "
                        f"got {ref.geometry_type.value} on {label}"
                    )
                    raise ValueError(msg)

        if self.mate_type is MateType.PARALLEL:
            # Any directional geometry is acceptable; only POINT is rejected.
            for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
                if ref.geometry_type is GeometryType.POINT:
                    msg = f"PARALLEL mate requires directional geometry, got POINT on {label}"
                    raise ValueError(msg)

        if self.mate_type is MateType.TANGENT:
            _surface = frozenset({GeometryType.FACE, GeometryType.EDGE})
            for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
                if ref.geometry_type not in _surface:
                    msg = (
                        f"TANGENT mate requires surface geometry "
                        f"(FACE or EDGE), got {ref.geometry_type.value} "
                        f"on {label}"
                    )
                    raise ValueError(msg)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        return {
            "mate_id": self.mate_id,
            "mate_type": self.mate_type.name,
            "ref_a": self.ref_a.to_dict(),
            "ref_b": self.ref_b.to_dict(),
            "value": self.value,
            "tolerance": self.tolerance,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Mate:
        """Construct from a dict produced by :meth:`to_dict`.

        Missing ``value``/``tolerance`` keys fall back to the dataclass
        defaults (0.0 and 1e-6).
        """
        return cls(
            mate_id=data["mate_id"],
            mate_type=MateType[data["mate_type"]],
            ref_a=GeometryRef.from_dict(data["ref_a"]),
            ref_b=GeometryRef.from_dict(data["ref_b"]),
            value=data.get("value", 0.0),
            tolerance=data.get("tolerance", 1e-6),
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Context-dependent DOF removal
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Lookup table: (MateType, ref_a GeometryType, ref_b GeometryType) -> DOF removed.
# Entries with None match any geometry type for that position.
# Combinations absent here (e.g. TANGENT edge-edge, which Mate.validate
# permits) fall back in dof_removed() to the wildcard entry or, failing
# that, to MateType.default_dof.
_DOF_TABLE: dict[tuple[MateType, GeometryType | None, GeometryType | None], int] = {
    # COINCIDENT — context-dependent
    (MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE): 3,
    (MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT): 3,
    (MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE): 3,
    (MateType.COINCIDENT, GeometryType.EDGE, GeometryType.EDGE): 2,
    (MateType.COINCIDENT, GeometryType.FACE, GeometryType.POINT): 1,
    (MateType.COINCIDENT, GeometryType.POINT, GeometryType.FACE): 1,
    # CONCENTRIC
    (MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS): 2,
    # PARALLEL
    (MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS): 2,
    (MateType.PARALLEL, GeometryType.FACE, GeometryType.FACE): 2,
    (MateType.PARALLEL, GeometryType.PLANE, GeometryType.PLANE): 2,
    # TANGENT
    (MateType.TANGENT, GeometryType.FACE, GeometryType.FACE): 1,
    (MateType.TANGENT, GeometryType.FACE, GeometryType.EDGE): 1,
    (MateType.TANGENT, GeometryType.EDGE, GeometryType.FACE): 1,
    # Types where DOF is always the same regardless of geometry
    (MateType.PERPENDICULAR, None, None): 1,
    (MateType.DISTANCE, None, None): 1,
    (MateType.ANGLE, None, None): 1,
    (MateType.LOCK, None, None): 6,
}
|
||||
|
||||
|
||||
def dof_removed(
    mate_type: MateType,
    ref_a: GeometryRef,
    ref_b: GeometryRef,
) -> int:
    """Return the number of DOF removed by a mate given its geometry context.

    Resolution order: the exact ``(mate_type, geom_a, geom_b)`` entry in
    ``_DOF_TABLE``, then the ``(mate_type, None, None)`` wildcard entry,
    and finally :attr:`MateType.default_dof`.

    Args:
        mate_type: The mate constraint type.
        ref_a: Geometry reference on the first body.
        ref_b: Geometry reference on the second body.

    Returns:
        Number of scalar DOF removed by this mate.
    """
    # Table values are ints, so None is a safe "not found" sentinel.
    exact = _DOF_TABLE.get((mate_type, ref_a.geometry_type, ref_b.geometry_type))
    if exact is not None:
        return exact

    wildcard = _DOF_TABLE.get((mate_type, None, None))
    if wildcard is not None:
        return wildcard

    return mate_type.default_dof
|
||||
0
solver/models/__init__.py
Normal file
0
solver/models/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
solver/training/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
0
tests/datagen/__init__.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
240
tests/datagen/test_analysis.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""Tests for solver.datagen.analysis -- combined analysis function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _two_bodies() -> list[RigidBody]:
    """Two rigid bodies 2 units apart along the x axis."""
    coords = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0)]
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
|
||||
|
||||
|
||||
def _triangle_bodies() -> list[RigidBody]:
    """Three rigid bodies arranged in a triangle in the z=0 plane."""
    coords = [(0.0, 0.0, 0.0), (2.0, 0.0, 0.0), (1.0, 1.7, 0.0)]
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTwoBodiesRevolute:
    """Demo scenario 1: two bodies connected by a revolute joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        """Analyze two bodies joined by a single z-axis revolute, body 0 grounded."""
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        """The free body retains exactly the revolute's single DOF."""
        assert result.jacobian_internal_dof == 1

    def test_not_rigid(self, result: ConstraintAnalysis) -> None:
        """One remaining internal DOF means the assembly is not rigid."""
        assert not result.is_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        """Combinatorial classification reports 'underconstrained'."""
        assert result.combinatorial_classification == "underconstrained"

    def test_no_redundant(self, result: ConstraintAnalysis) -> None:
        """A single joint introduces no redundant constraints."""
        assert result.combinatorial_redundant == 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTwoBodiesFixed:
    """Demo scenario 2: two bodies connected by a fixed joint."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        """Analyze two bodies welded by a fixed joint, body 0 grounded."""
        bodies = _two_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.FIXED,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_internal_dof(self, result: ConstraintAnalysis) -> None:
        """A fixed joint leaves no internal DOF."""
        assert result.jacobian_internal_dof == 0

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        """Zero internal DOF means the assembly is rigid."""
        assert result.is_rigid

    def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
        """Rigid with no redundancy: minimally rigid."""
        assert result.is_minimally_rigid

    def test_classification(self, result: ConstraintAnalysis) -> None:
        """Combinatorial classification reports 'well-constrained'."""
        assert result.combinatorial_classification == "well-constrained"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 3: Triangle with revolute joints (overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTriangleRevolute:
    """Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        """Close a loop of three z-axis revolutes — an overconstrained cycle."""
        bodies = _triangle_bodies()
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.5, 0.85, 0.0]),
                anchor_b=np.array([1.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                2,
                body_a=2,
                body_b=0,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([0.5, 0.85, 0.0]),
                anchor_b=np.array([0.5, 0.85, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_has_redundant(self, result: ConstraintAnalysis) -> None:
        """Closing the loop makes some constraints combinatorially redundant."""
        assert result.combinatorial_redundant > 0

    def test_classification(self, result: ConstraintAnalysis) -> None:
        """The cycle classifies as overconstrained (or mixed)."""
        assert result.combinatorial_classification in ("overconstrained", "mixed")

    def test_rigid(self, result: ConstraintAnalysis) -> None:
        """The closed triangle is rigid."""
        assert result.is_rigid

    def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
        """Some constraint rows are numerically (Jacobian) dependent."""
        assert len(result.numerically_dependent) > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Scenario 4: Parallel revolute axes (geometric degeneracy)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelRevoluteAxes:
    """Demo scenario 4: parallel revolute axes create geometric degeneracies."""

    @pytest.fixture()
    def result(self) -> ConstraintAnalysis:
        """Chain of three bodies with two parallel z-axis revolutes."""
        bodies = [
            RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
            RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
            RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
        ]
        joints = [
            Joint(
                0,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
            Joint(
                1,
                body_a=1,
                body_b=2,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([3.0, 0.0, 0.0]),
                anchor_b=np.array([3.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            ),
        ]
        return analyze_assembly(bodies, joints, ground_body=0)

    def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None:
        """Parallel axes produce at least one geometric degeneracy."""
        assert result.geometric_degeneracies > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestNoJoints:
    """Assembly with bodies but no joints."""

    def test_all_dof_free(self) -> None:
        """An unjointed body keeps all its DOF, so the assembly is not rigid."""
        outcome = analyze_assembly(_two_bodies(), [], ground_body=0)
        # Body 0 is grounded; body 1 floats freely (6 DOF).
        assert outcome.jacobian_internal_dof > 0
        assert not outcome.is_rigid

    def test_ungrounded(self) -> None:
        """Without joints the assembly classifies as underconstrained."""
        outcome = analyze_assembly(_two_bodies(), [])
        assert outcome.combinatorial_classification == "underconstrained"
|
||||
|
||||
|
||||
class TestReturnType:
    """Verify the return object is a proper ConstraintAnalysis."""

    def test_instance(self) -> None:
        """analyze_assembly returns a ConstraintAnalysis instance."""
        outcome = analyze_assembly(_two_bodies(), [Joint(0, 0, 1, JointType.FIXED)])
        assert isinstance(outcome, ConstraintAnalysis)

    def test_per_edge_results_populated(self) -> None:
        """A single revolute joint yields five per-edge constraint entries."""
        outcome = analyze_assembly(_two_bodies(), [Joint(0, 0, 1, JointType.REVOLUTE)])
        assert len(outcome.per_edge_results) == 5
|
||||
337
tests/datagen/test_dataset.py
Normal file
337
tests/datagen/test_dataset.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
_derive_shard_seed,
|
||||
_parse_scalar,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetConfig:
    """Construction and defaults of DatasetConfig."""

    def test_defaults(self) -> None:
        defaults = DatasetConfig()
        assert defaults.num_assemblies == 100_000
        assert defaults.seed == 42
        assert defaults.shard_size == 1000
        assert defaults.num_workers == 4

    def test_from_dict_flat(self) -> None:
        cfg = DatasetConfig.from_dict({"num_assemblies": 500, "seed": 123})
        assert (cfg.num_assemblies, cfg.seed) == (500, 123)

    def test_from_dict_nested_body_count(self) -> None:
        # Nested {"body_count": {min, max}} maps onto the flat attributes.
        cfg = DatasetConfig.from_dict({"body_count": {"min": 3, "max": 20}})
        assert (cfg.body_count_min, cfg.body_count_max) == (3, 20)

    def test_from_dict_flat_body_count(self) -> None:
        cfg = DatasetConfig.from_dict({"body_count_min": 5, "body_count_max": 30})
        assert (cfg.body_count_min, cfg.body_count_max) == (5, 30)

    def test_from_dict_complexity_distribution(self) -> None:
        distribution = {"simple": 0.6, "complex": 0.4}
        cfg = DatasetConfig.from_dict({"complexity_distribution": distribution})
        assert cfg.complexity_distribution == distribution

    def test_from_dict_templates(self) -> None:
        cfg = DatasetConfig.from_dict({"templates": ["chain", "tree"]})
        assert cfg.templates == ["chain", "tree"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseScalar:
    """_parse_scalar coerces raw YAML scalar text into Python values."""

    def test_int(self) -> None:
        value = _parse_scalar("42")
        assert value == 42

    def test_float(self) -> None:
        value = _parse_scalar("3.14")
        assert value == 3.14

    def test_bool_true(self) -> None:
        value = _parse_scalar("true")
        assert value is True

    def test_bool_false(self) -> None:
        value = _parse_scalar("false")
        assert value is False

    def test_string(self) -> None:
        value = _parse_scalar("hello")
        assert value == "hello"

    def test_inline_comment(self) -> None:
        # A trailing "# ..." comment is stripped before the value is parsed.
        value = _parse_scalar("0.4 # some comment")
        assert value == 0.4
|
||||
|
||||
|
||||
class TestParseSimpleYaml:
    """parse_simple_yaml understands the subset used by synthetic.yaml."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        path = tmp_path / "test.yaml"
        path.write_text("name: test\nnum: 42\nratio: 0.5\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["name"] == "test"
        assert parsed["num"] == 42
        assert parsed["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        path = tmp_path / "test.yaml"
        path.write_text("body_count:\n min: 2\n max: 50\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        path = tmp_path / "test.yaml"
        path.write_text("templates:\n - chain\n - tree\n - loop\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        path = tmp_path / "test.yaml"
        path.write_text("dist:\n simple: 0.4 # comment\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Round-trip the real project config (relative to the repo root)."""
        parsed = parse_simple_yaml("configs/dataset/synthetic.yaml")
        assert parsed["name"] == "synthetic"
        assert parsed["num_assemblies"] == 100000
        assert isinstance(parsed["complexity_distribution"], dict)
        assert isinstance(parsed["templates"], list)
        assert parsed["shard_size"] == 1000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShardSeedDerivation:
    """_derive_shard_seed: deterministic, distinct across shards and seeds."""

    def test_deterministic(self) -> None:
        # Same (global seed, shard index) must always map to the same seed.
        assert _derive_shard_seed(42, 0) == _derive_shard_seed(42, 0)

    def test_different_shards(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetGenerator — small end-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetGenerator:
    """End-to-end generation runs over tiny datasets.

    Every test drives ``DatasetGenerator.run()`` with a small single-worker
    config and then inspects the artifacts it writes under the output dir:
    ``shards/``, ``index.json``, ``stats.json``. The repeated config
    construction from the original tests is factored into ``_run``.
    """

    @staticmethod
    def _run(out_dir: Path, num_assemblies: int, shard_size: int) -> None:
        """Build a seed-42 single-worker config rooted at *out_dir* and run it."""
        cfg = DatasetConfig(
            num_assemblies=num_assemblies,
            output_dir=str(out_dir),
            shard_size=shard_size,
            seed=42,
            num_workers=1,
        )
        DatasetGenerator(cfg).run()

    def test_small_generation(self, tmp_path: Path) -> None:
        """10 examples fit in one shard; index and stats files are written."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=10, shard_size=10)

        shards_dir = out / "shards"
        assert shards_dir.exists()
        assert len(sorted(shards_dir.glob("shard_*"))) == 1

        index_file = out / "index.json"
        assert index_file.exists()
        index = json.loads(index_file.read_text())
        assert index["total_assemblies"] == 10

        assert (out / "stats.json").exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """20 examples with shard_size=10 land in exactly two shards."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=20, shard_size=10)
        assert len(sorted((out / "shards").glob("shard_*"))) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """A re-run after deleting stats.json rebuilds stats but not shards."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=20, shard_size=10)

        # Record shard modification times to detect regeneration.
        shards_dir = out / "shards"
        mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}

        # Simulate an interrupted run: stats missing, shards intact.
        (out / "stats.json").unlink()
        self._run(out, num_assemblies=20, shard_size=10)

        # Shards must NOT have been regenerated...
        for p in shards_dir.glob("shard_*"):
            assert p.stat().st_mtime == mtimes[p.name]
        # ...but stats must be back.
        assert (out / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """The .checkpoint.json scratch file is cleaned up after completion."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=5, shard_size=5)
        assert not (out / ".checkpoint.json").exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json carries the expected top-level sections."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=10, shard_size=10)
        stats = json.loads((out / "stats.json").read_text())
        assert stats["total_examples"] == 10
        for key in (
            "classification_distribution",
            "body_count_histogram",
            "joint_type_distribution",
            "dof_statistics",
            "geometric_degeneracy",
            "rigidity",
        ):
            assert key in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json records format version, totals, and per-shard info."""
        out = tmp_path / "output"
        self._run(out, num_assemblies=15, shard_size=10)
        index = json.loads((out / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        assert index["total_shards"] == 2
        assert "shards" in index
        for info in index["shards"].values():
            assert "start_id" in info
            assert "count" in info

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Two runs with the same seed produce identical statistics."""
        for run_dir in ("run1", "run2"):
            self._run(tmp_path / run_dir, num_assemblies=5, shard_size=5)

        s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
        s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert s1["total_examples"] == s2["total_examples"]
        assert s1["classification_distribution"] == s2["classification_distribution"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI integration test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCLI:
    """Drive scripts/generate_synthetic.py through a subprocess."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        out_dir = tmp_path / "cli_out"
        # NOTE(review): cwd and PYTHONPATH are hard-coded to /home/developer,
        # which ties this test to one environment — consider deriving the
        # repository root instead.
        argv = [
            sys.executable,
            "scripts/generate_synthetic.py",
            "--num-assemblies", "5",
            "--output-dir", str(out_dir),
            "--shard-size", "5",
            "--num-workers", "1",
            "--seed", "42",
        ]
        proc = subprocess.run(
            argv,
            capture_output=True,
            text=True,
            cwd="/home/developer",
            timeout=120,
            env={**os.environ, "PYTHONPATH": "/home/developer"},
        )
        assert proc.returncode == 0, (
            f"CLI failed:\nstdout: {proc.stdout}\nstderr: {proc.stderr}"
        )
        assert (out_dir / "index.json").exists()
        assert (out_dir / "stats.json").exists()
|
||||
682
tests/datagen/test_generator.py
Normal file
682
tests/datagen/test_generator.py
Normal file
@@ -0,0 +1,682 @@
|
||||
"""Tests for solver.datagen.generator -- synthetic assembly generation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
|
||||
from solver.datagen.types import JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Original generators (chain / rigid / overconstrained)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestChainAssembly:
    """generate_chain_assembly yields valid underconstrained chains."""

    def test_returns_three_tuple(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        bodies, joints, _ = maker.generate_chain_assembly(4)
        # n bodies are linked by n - 1 joints.
        assert (len(bodies), len(joints)) == (4, 3)

    def test_chain_underconstrained(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        _, _, analysis = maker.generate_chain_assembly(4)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_chain_body_ids(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        bodies, _, _ = maker.generate_chain_assembly(5)
        assert [b.body_id for b in bodies] == list(range(5))

    def test_chain_joint_connectivity(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        _, joints, _ = maker.generate_chain_assembly(4)
        # Link i connects body i to body i + 1.
        for idx, joint in enumerate(joints):
            assert (joint.body_a, joint.body_b) == (idx, idx + 1)

    def test_chain_custom_joint_type(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        _, joints, _ = maker.generate_chain_assembly(3, joint_type=JointType.BALL)
        assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
|
||||
class TestRigidAssembly:
    """generate_rigid_assembly yields rigid assemblies."""

    def test_rigid(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = maker.generate_rigid_assembly(4)
        assert analysis.is_rigid

    def test_spanning_tree_structure(self) -> None:
        """A connected assembly over n bodies needs at least n - 1 joints."""
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_rigid_assembly(5)
        assert len(joints) >= len(bodies) - 1

    def test_deterministic(self) -> None:
        """Identical seeds reproduce rank and joint count exactly."""
        first = SyntheticAssemblyGenerator(seed=99).generate_rigid_assembly(4)
        second = SyntheticAssemblyGenerator(seed=99).generate_rigid_assembly(4)
        assert first[2].jacobian_rank == second[2].jacobian_rank
        assert len(first[1]) == len(second[1])
|
||||
|
||||
|
||||
class TestOverconstrainedAssembly:
    """generate_overconstrained_assembly injects redundant constraints."""

    def test_has_redundant(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = maker.generate_overconstrained_assembly(4, extra_joints=2)
        assert analysis.combinatorial_redundant > 0

    def test_extra_joints_added(self) -> None:
        # With identical seeds, the overconstrained variant must carry more
        # joints than the plain rigid baseline it builds on.
        _, base_joints, _ = SyntheticAssemblyGenerator(seed=42).generate_rigid_assembly(4)
        _, over_joints, _ = SyntheticAssemblyGenerator(
            seed=42
        ).generate_overconstrained_assembly(4, extra_joints=3)
        assert len(over_joints) > len(base_joints)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# New topology generators
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTreeAssembly:
    """generate_tree_assembly yields tree-structured assemblies."""

    def test_body_and_joint_counts(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_tree_assembly(6)
        # A tree over n nodes always has n - 1 edges.
        assert (len(bodies), len(joints)) == (6, 5)

    def test_underconstrained(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = maker.generate_tree_assembly(6)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_branching_factor(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_tree_assembly(10, branching_factor=2)
        assert (len(bodies), len(joints)) == (10, 9)

    def test_mixed_joint_types(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        pool = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
        _, joints, _ = maker.generate_tree_assembly(10, joint_types=pool)
        # Nine joints drawn from three types: at least two should show up.
        assert len({j.joint_type for j in joints}) >= 2

    def test_single_joint_type(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_tree_assembly(5, joint_types=JointType.BALL)
        assert all(j.joint_type is JointType.BALL for j in joints)

    def test_sequential_body_ids(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = maker.generate_tree_assembly(7)
        assert [b.body_id for b in bodies] == list(range(7))
|
||||
|
||||
|
||||
class TestLoopAssembly:
    """generate_loop_assembly yields closed-loop assemblies."""

    def test_body_and_joint_counts(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_loop_assembly(5)
        # Closing the loop means exactly one joint per body.
        assert (len(bodies), len(joints)) == (5, 5)

    def test_has_redundancy(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = maker.generate_loop_assembly(5)
        assert analysis.combinatorial_redundant > 0

    def test_wrap_around_connectivity(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_loop_assembly(4)
        edges = {(j.body_a, j.body_b) for j in joints}
        # Consecutive links plus the closing (3, 0) wrap-around edge.
        assert edges >= {(0, 1), (1, 2), (2, 3), (3, 0)}

    def test_minimum_bodies_error(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 3"):
            maker.generate_loop_assembly(2)

    def test_mixed_joint_types(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_loop_assembly(
            8, joint_types=[JointType.REVOLUTE, JointType.FIXED]
        )
        assert len({j.joint_type for j in joints}) >= 2
|
||||
|
||||
|
||||
class TestStarAssembly:
    """generate_star_assembly yields hub-and-spoke assemblies."""

    def test_body_and_joint_counts(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_star_assembly(6)
        # One spoke per non-hub body: n - 1 joints.
        assert (len(bodies), len(joints)) == (6, 5)

    def test_all_joints_connect_to_hub(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_star_assembly(6)
        # Body 0 is the hub; every spoke touches it.
        assert all(0 in (j.body_a, j.body_b) for j in joints)

    def test_underconstrained(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = maker.generate_star_assembly(5)
        assert analysis.combinatorial_classification == "underconstrained"

    def test_minimum_bodies_error(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        with pytest.raises(ValueError, match="at least 2"):
            maker.generate_star_assembly(1)

    def test_hub_at_origin(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = maker.generate_star_assembly(4)
        np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
|
||||
|
||||
|
||||
class TestMixedAssembly:
    """generate_mixed_assembly yields tree + extra-edge hybrids."""

    def test_more_joints_than_tree(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, joints, _ = maker.generate_mixed_assembly(8, edge_density=0.3)
        assert len(joints) > len(bodies) - 1

    def test_density_zero_is_tree(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_mixed_assembly(5, edge_density=0.0)
        # No extra edges at density 0: just the n - 1 spanning-tree joints.
        assert len(joints) == 4

    def test_density_validation(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        for bad_density in (1.5, -0.1):
            with pytest.raises(ValueError, match="must be in"):
                maker.generate_mixed_assembly(5, edge_density=bad_density)

    def test_no_duplicate_edges(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_mixed_assembly(6, edge_density=0.5)
        undirected = [frozenset((j.body_a, j.body_b)) for j in joints]
        assert len(undirected) == len(set(undirected))

    def test_high_density(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_mixed_assembly(5, edge_density=1.0)
        # Complete graph on 5 nodes: C(5, 2) == 10 edges.
        assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAxisStrategy:
    """Axis-sampling strategies emit valid unit vectors."""

    def test_cardinal_axis_from_six(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        # 200 draws should cover all six signed cardinal directions.
        seen = {tuple(maker._cardinal_axis()) for _ in range(200)}
        signed_cardinals = {
            (1, 0, 0), (-1, 0, 0),
            (0, 1, 0), (0, -1, 0),
            (0, 0, 1), (0, 0, -1),
        }
        assert seen == signed_cardinals

    def test_random_axis_unit_norm(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        for _ in range(50):
            assert abs(np.linalg.norm(maker._sample_axis("random")) - 1.0) < 1e-10

    def test_near_parallel_close_to_base(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        base = np.array([0.0, 0.0, 1.0])
        for _ in range(50):
            candidate = maker._near_parallel_axis(base)
            # Unit length, and strongly aligned with the base direction.
            assert abs(np.linalg.norm(candidate) - 1.0) < 1e-10
            assert np.dot(candidate, base) > 0.95

    def test_sample_axis_cardinal(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        picked = maker._sample_axis("cardinal")
        options = [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
        assert any(np.allclose(picked, np.array(o, dtype=float)) for o in options)

    def test_sample_axis_near_parallel(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=0)
        # "near_parallel" sampling hugs the +z axis.
        assert np.dot(maker._sample_axis("near_parallel"), np.array([0.0, 0.0, 1.0])) > 0.95
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: orientations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRandomOrientations:
    """Generated bodies carry random (non-identity) rotations."""

    def test_bodies_have_orientations(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = maker.generate_tree_assembly(5)
        rotated = [b for b in bodies if not np.allclose(b.orientation, np.eye(3))]
        assert len(rotated) >= 3

    def test_orientations_are_valid_rotations(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        bodies, _, _ = maker.generate_star_assembly(6)
        for body in bodies:
            rot = body.orientation
            # Proper rotation: orthogonal (R^T R = I) with determinant +1.
            np.testing.assert_allclose(rot.T @ rot, np.eye(3), atol=1e-10)
            assert abs(np.linalg.det(rot) - 1.0) < 1e-10

    def test_all_generators_set_orientations(self) -> None:
        # Call order (chain, loop, mixed) matches the original so the shared
        # RNG advances identically.
        maker = SyntheticAssemblyGenerator(seed=42)
        for build in (
            lambda: maker.generate_chain_assembly(3),
            lambda: maker.generate_loop_assembly(4),
            lambda: maker.generate_mixed_assembly(4),
        ):
            bodies, _, _ = build()
            assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: grounded parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGroundedParameter:
    """The grounded flag controls whether a ground body anchors the analysis."""

    def test_chain_grounded_default(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(4)
        assert analysis.combinatorial_dof >= 0

    def test_chain_floating(self) -> None:
        gen = SyntheticAssemblyGenerator(seed=42)
        _, _, analysis = gen.generate_chain_assembly(4, grounded=False)
        # Floating: the 6 trivial rigid-body DOF are not subtracted by ground.
        assert analysis.combinatorial_dof >= 6

    def test_floating_vs_grounded_dof_difference(self) -> None:
        gen1 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
        gen2 = SyntheticAssemblyGenerator(seed=42)
        _, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
        # Removing the ground constraint must free additional DOF.
        assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof

    @pytest.mark.parametrize(
        "gen_method",
        [
            "generate_chain_assembly",
            "generate_rigid_assembly",
            "generate_tree_assembly",
            "generate_loop_assembly",
            "generate_star_assembly",
            "generate_mixed_assembly",
        ],
    )
    def test_all_generators_accept_grounded(
        self,
        gen_method: str,
    ) -> None:
        """Every generator accepts grounded=False without raising."""
        gen = SyntheticAssemblyGenerator(seed=42)
        method = getattr(gen, gen_method)
        # The original had an if/else whose branches were byte-identical
        # (`method(n, grounded=False)` in both) — the dead branch is removed.
        method(4, grounded=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: parallel axis injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelAxisInjection:
    """parallel_axis_prob forces joints to share an axis direction."""

    @staticmethod
    def _assert_all_parallel(joints) -> None:
        # Every axis must be (anti-)parallel to the first joint's axis:
        # |dot| close to 1.
        reference = joints[0].axis
        for joint in joints[1:]:
            assert abs(np.dot(joint.axis, reference)) > 0.9

    def test_parallel_axes_similar(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_chain_assembly(6, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)

    def test_zero_prob_no_forced_parallel(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_chain_assembly(6, parallel_axis_prob=0.0)
        reference = joints[0].axis
        alignment = [abs(np.dot(j.axis, reference)) for j in joints[1:]]
        # Five independently random axes will essentially never all align.
        assert min(alignment) < 0.95

    def test_parallel_on_loop(self) -> None:
        """Parallel axes hold on a loop assembly too."""
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_loop_assembly(5, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)

    def test_parallel_on_star(self) -> None:
        """Parallel axes hold on a star assembly too."""
        maker = SyntheticAssemblyGenerator(seed=42)
        _, joints, _ = maker.generate_star_assembly(5, parallel_axis_prob=1.0)
        self._assert_all_parallel(joints)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestComplexityTiers:
    """complexity_tier constrains body counts in batch generation."""

    @staticmethod
    def _assert_in_tier(batch, tier: str) -> None:
        # Every example's body count must fall in [lo, hi) for the tier.
        lo, hi = COMPLEXITY_RANGES[tier]
        for example in batch:
            assert lo <= example["n_bodies"] < hi

    def test_simple_range(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        batch = maker.generate_training_batch(20, complexity_tier="simple")
        self._assert_in_tier(batch, "simple")

    def test_medium_range(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        batch = maker.generate_training_batch(20, complexity_tier="medium")
        self._assert_in_tier(batch, "medium")

    def test_complex_range(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        batch = maker.generate_training_batch(3, complexity_tier="complex")
        self._assert_in_tier(batch, "complex")

    def test_tier_overrides_range(self) -> None:
        maker = SyntheticAssemblyGenerator(seed=42)
        batch = maker.generate_training_batch(
            10, n_bodies_range=(2, 3), complexity_tier="medium"
        )
        # The tier wins over an explicit n_bodies_range.
        self._assert_in_tier(batch, "medium")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Training batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTrainingBatch:
    """generate_training_batch produces well-structured examples."""

    # Every training example must carry exactly this key set.
    EXPECTED_KEYS: ClassVar[set[str]] = {
        "example_id",
        "generator_type",
        "grounded",
        "n_bodies",
        "n_joints",
        "body_positions",
        "body_orientations",
        "joints",
        "joint_labels",
        "labels",
        "assembly_classification",
        "is_rigid",
        "is_minimally_rigid",
        "internal_dof",
        "geometric_degeneracies",
    }

    # The closed set of structural generators the batch may draw from.
    VALID_GEN_TYPES: ClassVar[set[str]] = {
        "chain",
        "rigid",
        "overconstrained",
        "tree",
        "loop",
        "star",
        "mixed",
    }

    def test_batch_size(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        assert len(generator.generate_training_batch(20)) == 20

    def test_example_keys(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        for example in generator.generate_training_batch(10):
            assert set(example.keys()) == self.EXPECTED_KEYS

    def test_example_ids_sequential(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        ids = [example["example_id"] for example in generator.generate_training_batch(15)]
        assert ids == list(range(15))

    def test_generator_type_valid(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(50)
        assert all(example["generator_type"] in self.VALID_GEN_TYPES for example in examples)

    def test_generator_type_diversity(self) -> None:
        """100-sample batch should use at least 5 of 7 generator types."""
        generator = SyntheticAssemblyGenerator(seed=42)
        seen = {example["generator_type"] for example in generator.generate_training_batch(100)}
        assert len(seen) >= 5

    def test_default_body_range(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        for example in generator.generate_training_batch(30):
            # Default range is (3, 8); loop/star generators may clamp lower.
            assert 2 <= example["n_bodies"] <= 7

    def test_joint_label_consistency(self) -> None:
        """independent + redundant == total for every joint."""
        generator = SyntheticAssemblyGenerator(seed=42)
        for example in generator.generate_training_batch(30):
            for label in example["joint_labels"].values():
                combined = label["independent_constraints"] + label["redundant_constraints"]
                assert combined == label["total_constraints"]

    def test_body_orientations_present(self) -> None:
        """Each example includes body_orientations as 3x3 lists."""
        generator = SyntheticAssemblyGenerator(seed=42)
        for example in generator.generate_training_batch(10):
            orientations = example["body_orientations"]
            assert len(orientations) == example["n_bodies"]
            for matrix in orientations:
                assert len(matrix) == 3
                assert len(matrix[0]) == 3

    def test_labels_structure(self) -> None:
        """Each example has labels dict with expected sub-keys."""
        generator = SyntheticAssemblyGenerator(seed=42)
        for example in generator.generate_training_batch(10):
            labels = example["labels"]
            for key in ("per_constraint", "per_joint", "per_body", "assembly"):
                assert key in labels

    def test_grounded_field_present(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(10)
        assert all(isinstance(example["grounded"], bool) for example in examples)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch grounded ratio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchGroundedRatio:
    """grounded_ratio controls the mix in batch generation."""

    def test_all_grounded(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(20, grounded_ratio=1.0)
        assert all(example["grounded"] for example in examples)

    def test_none_grounded(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(20, grounded_ratio=0.0)
        assert all(not example["grounded"] for example in examples)

    def test_mixed_ratio(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(100, grounded_ratio=0.5)
        n_grounded = sum(1 for example in examples if example["grounded"])
        # With 100 samples and p=0.5 the count should stay well inside (20, 80).
        assert 20 < n_grounded < 80

    def test_batch_axis_strategy_cardinal(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(10, axis_strategy="cardinal")
        assert len(examples) == 10

    def test_batch_parallel_axis_prob(self) -> None:
        generator = SyntheticAssemblyGenerator(seed=42)
        examples = generator.generate_training_batch(10, parallel_axis_prob=0.5)
        assert len(examples) == 10
||||
# ---------------------------------------------------------------------------
|
||||
# Seed reproducibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSeedReproducibility:
    """Different seeds produce different results."""

    def test_different_seeds_differ(self) -> None:
        batches = [
            SyntheticAssemblyGenerator(seed=seed).generate_training_batch(
                batch_size=5,
                n_bodies_range=(3, 6),
            )
            for seed in (1, 2)
        ]
        classifications = [
            [example["assembly_classification"] for example in batch] for batch in batches
        ]
        rigidity = [[example["is_rigid"] for example in batch] for batch in batches]
        # At least one of the two label sequences must diverge between seeds.
        assert classifications[0] != classifications[1] or rigidity[0] != rigidity[1]

    def test_same_seed_identical(self) -> None:
        first = SyntheticAssemblyGenerator(seed=123)
        second = SyntheticAssemblyGenerator(seed=123)
        bodies_a, joints_a, _ = first.generate_tree_assembly(5)
        bodies_b, joints_b, _ = second.generate_tree_assembly(5)
        for body_a, body_b in zip(bodies_a, bodies_b, strict=True):
            np.testing.assert_array_almost_equal(body_a.position, body_b.position)
        assert len(joints_a) == len(joints_b)
||||
267
tests/datagen/test_jacobian.py
Normal file
267
tests/datagen/test_jacobian.py
Normal file
@@ -0,0 +1,267 @@
|
||||
"""Tests for solver.datagen.jacobian -- Jacobian rank verification."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.jacobian import JacobianVerifier
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _two_bodies() -> list[RigidBody]:
    """Two bodies 2 units apart along X."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
||||
def _three_bodies() -> list[RigidBody]:
    """Three collinear bodies spaced 2 units apart along X."""
    coords = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [4.0, 0.0, 0.0])
    return [RigidBody(i, position=np.array(c)) for i, c in enumerate(coords)]
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJacobianShape:
    """Verify Jacobian matrix dimensions for each joint type."""

    @pytest.mark.parametrize(
        "joint_type,expected_rows",
        [
            (JointType.FIXED, 6),
            (JointType.REVOLUTE, 5),
            (JointType.CYLINDRICAL, 4),
            (JointType.SLIDER, 5),
            (JointType.BALL, 3),
            (JointType.PLANAR, 3),
            (JointType.SCREW, 5),
            (JointType.UNIVERSAL, 4),
            (JointType.PARALLEL, 3),
            (JointType.PERPENDICULAR, 1),
            (JointType.DISTANCE, 1),
        ],
    )
    def test_row_count(self, joint_type: JointType, expected_rows: int) -> None:
        verifier = JacobianVerifier(_two_bodies())
        rows_added = verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=joint_type,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        assert rows_added == expected_rows

        # Column count is fixed: 6 per body, 2 bodies.
        assert verifier.get_jacobian().shape == (expected_rows, 12)
|
||||
class TestNumericalRank:
    """Numerical rank checks for known configurations."""

    @staticmethod
    def _rank_with_joint(joint_type: JointType, axis: np.ndarray | None = None) -> int:
        # Build a fresh verifier with a single joint between two bodies
        # and return the resulting Jacobian rank.
        verifier = JacobianVerifier(_two_bodies())
        extra = {} if axis is None else {"axis": axis}
        verifier.add_joint_constraints(
            Joint(
                joint_id=0,
                body_a=0,
                body_b=1,
                joint_type=joint_type,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                **extra,
            )
        )
        return verifier.numerical_rank()

    def test_fixed_joint_rank_six(self) -> None:
        """Fixed joint between 2 bodies: rank = 6."""
        assert self._rank_with_joint(JointType.FIXED) == 6

    def test_revolute_joint_rank_five(self) -> None:
        """Revolute joint between 2 bodies: rank = 5."""
        rank = self._rank_with_joint(JointType.REVOLUTE, axis=np.array([0.0, 0.0, 1.0]))
        assert rank == 5

    def test_ball_joint_rank_three(self) -> None:
        """Ball joint between 2 bodies: rank = 3."""
        assert self._rank_with_joint(JointType.BALL) == 3

    def test_empty_jacobian_rank_zero(self) -> None:
        assert JacobianVerifier(_two_bodies()).numerical_rank() == 0
||||
|
||||
class TestParallelAxesDegeneracy:
    """Parallel revolute axes create geometric dependencies."""

    def _four_body_loop(self) -> list[RigidBody]:
        # Square of side 2 in the XY plane.
        corners = ([0.0, 0.0, 0.0], [2.0, 0.0, 0.0], [2.0, 2.0, 0.0], [0.0, 2.0, 0.0])
        return [RigidBody(i, position=np.array(c)) for i, c in enumerate(corners)]

    def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]:
        # Closed loop 0-1-2-3-0 with edge-midpoint anchors.
        pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])]
        joints = []
        for i, (a, b, anchor) in enumerate(pairs):
            joints.append(
                Joint(
                    joint_id=i,
                    body_a=a,
                    body_b=b,
                    joint_type=JointType.REVOLUTE,
                    anchor_a=np.array(anchor, dtype=float),
                    anchor_b=np.array(anchor, dtype=float),
                    axis=axes[i],
                )
            )
        return joints

    def _loop_rank(self, axes: list[np.ndarray]) -> int:
        # Rank of the assembled loop Jacobian for the given axis choices.
        verifier = JacobianVerifier(self._four_body_loop())
        for joint in self._loop_joints(axes):
            verifier.add_joint_constraints(joint)
        return verifier.numerical_rank()

    def test_parallel_has_lower_rank(self) -> None:
        """4-body closed loop: all-parallel revolute axes produce lower
        Jacobian rank than mixed axes due to geometric dependency."""
        z_axis = np.array([0.0, 0.0, 1.0])
        mixed_axes = [
            np.array([0.0, 0.0, 1.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 0.0, 0.0]),
        ]
        assert self._loop_rank([z_axis] * 4) < self._loop_rank(mixed_axes)
|
||||
class TestFindDependencies:
    """Dependency detection."""

    @staticmethod
    def _fixed_joint(jid: int) -> Joint:
        # Fixed joint between bodies 0 and 1 at a shared anchor.
        return Joint(
            joint_id=jid,
            body_a=0,
            body_b=1,
            joint_type=JointType.FIXED,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([1.0, 0.0, 0.0]),
        )

    def test_fixed_joint_no_dependencies(self) -> None:
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(self._fixed_joint(0))
        assert verifier.find_dependencies() == []

    def test_duplicate_fixed_has_dependencies(self) -> None:
        """Two fixed joints on same pair: second is fully dependent."""
        verifier = JacobianVerifier(_two_bodies())
        for jid in (0, 1):
            verifier.add_joint_constraints(self._fixed_joint(jid))
        # All six rows of the duplicate joint are redundant.
        assert len(verifier.find_dependencies()) == 6

    def test_empty_no_dependencies(self) -> None:
        assert JacobianVerifier(_two_bodies()).find_dependencies() == []
||||
|
||||
class TestRowLabels:
    """Row label metadata."""

    def test_labels_match_rows(self) -> None:
        verifier = JacobianVerifier(_two_bodies())
        verifier.add_joint_constraints(
            Joint(
                joint_id=7,
                body_a=0,
                body_b=1,
                joint_type=JointType.REVOLUTE,
                anchor_a=np.array([1.0, 0.0, 0.0]),
                anchor_b=np.array([1.0, 0.0, 0.0]),
                axis=np.array([0.0, 0.0, 1.0]),
            )
        )
        labels = verifier.row_labels
        # One label per revolute constraint row, each tagged with the joint id.
        assert len(labels) == 5
        assert all(label["joint_id"] == 7 for label in labels)
||||
|
||||
class TestPerpendicularPair:
    """Internal _perpendicular_pair utility."""

    @pytest.mark.parametrize(
        "axis",
        [
            np.array([1.0, 0.0, 0.0]),
            np.array([0.0, 1.0, 0.0]),
            np.array([0.0, 0.0, 1.0]),
            np.array([1.0, 1.0, 1.0]) / np.sqrt(3),
        ],
    )
    def test_orthonormal(self, axis: np.ndarray) -> None:
        verifier = JacobianVerifier(_two_bodies())
        t1, t2 = verifier._perpendicular_pair(axis)

        # Both tangents have unit length.
        for vec in (t1, t2):
            np.testing.assert_allclose(np.linalg.norm(vec), 1.0, atol=1e-12)

        # Axis and tangents are pairwise perpendicular.
        for u, w in ((axis, t1), (axis, t2), (t1, t2)):
            np.testing.assert_allclose(np.dot(u, w), 0.0, atol=1e-12)
346
tests/datagen/test_labeling.py
Normal file
346
tests/datagen/test_labeling.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.labeling import (
|
||||
label_assembly,
|
||||
)
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
    """One RigidBody per position, with sequential ids starting at 0."""
    return [
        RigidBody(body_id=index, position=np.array(coords))
        for index, coords in enumerate(positions)
    ]
||||
|
||||
def _make_joint(
    jid: int,
    a: int,
    b: int,
    jtype: JointType,
    axis: tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Joint:
    """Joint of type *jtype* between bodies *a* and *b*, anchored at the origin."""
    joint = Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=jtype,
        anchor_a=np.zeros(3),
        anchor_b=np.zeros(3),
        axis=np.array(axis),
    )
    return joint
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-constraint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstraintLabels:
    """Per-constraint labels combine pebble game and Jacobian results."""

    def test_fixed_joint_all_independent(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.FIXED)], ground_body=0)
        assert len(labels.per_constraint) == 6
        assert all(cl.pebble_independent is True for cl in labels.per_constraint)
        assert all(cl.jacobian_independent is True for cl in labels.per_constraint)

    def test_revolute_joint_all_independent(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        assert len(labels.per_constraint) == 5
        assert all(cl.pebble_independent is True for cl in labels.per_constraint)
        assert all(cl.jacobian_independent is True for cl in labels.per_constraint)

    def test_chain_constraint_count(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        # Two revolutes contribute 5 scalar constraints each.
        assert len(labels.per_constraint) == 10

    def test_constraint_joint_ids(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        by_joint = {
            jid: [c for c in labels.per_constraint if c.joint_id == jid] for jid in (0, 1)
        }
        assert len(by_joint[0]) == 5  # revolute
        assert len(by_joint[1]) == 3  # ball

    def test_overconstrained_has_pebble_redundant(self) -> None:
        """Triangle with revolute joints: some constraints redundant."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert any(not c.pebble_independent for c in labels.per_constraint)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-joint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJointLabels:
    """Per-joint aggregated labels."""

    def test_fixed_joint_counts(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.FIXED)], ground_body=0)
        assert len(labels.per_joint) == 1
        (joint_label,) = labels.per_joint
        assert joint_label.joint_id == 0
        assert joint_label.independent_count == 6
        assert joint_label.redundant_count == 0
        assert joint_label.total == 6

    def test_overconstrained_has_redundant_joints(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        # A closed triangle must produce at least one redundant constraint.
        assert sum(jl.redundant_count for jl in labels.per_joint) > 0

    def test_joint_total_equals_dof(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.BALL)], ground_body=0)
        # A ball joint carries exactly 3 scalar constraints.
        assert labels.per_joint[0].total == 3
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-body DOF labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBodyDofLabels:
    """Per-body DOF signatures from nullspace projection."""

    @staticmethod
    def _body(labels, body_id: int):
        # Fetch the per-body label for a specific body id.
        return next(b for b in labels.per_body if b.body_id == body_id)

    def test_fixed_joint_grounded_both_zero(self) -> None:
        """Two bodies + fixed joint + grounded: both fully constrained."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.FIXED)], ground_body=0)
        for body_label in labels.per_body:
            assert body_label.translational_dof == 0
            assert body_label.rotational_dof == 0

    def test_revolute_has_rotational_dof(self) -> None:
        """Two bodies + revolute + grounded: body 1 has rotational DOF."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        # A revolute leaves at least one rotational freedom on the free body.
        assert self._body(labels, 1).rotational_dof >= 1

    def test_dof_bounds(self) -> None:
        """All DOF values should be in [0, 3]."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        for body_label in labels.per_body:
            assert 0 <= body_label.translational_dof <= 3
            assert 0 <= body_label.rotational_dof <= 3

    def test_floating_more_dof_than_grounded(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]

        def total_dof(ground: int | None) -> int:
            result = label_assembly(bodies, joints, ground_body=ground)
            return sum(b.translational_dof + b.rotational_dof for b in result.per_body)

        assert total_dof(None) > total_dof(0)

    def test_grounded_body_zero_dof(self) -> None:
        """The grounded body should have 0 DOF."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        ground = self._body(labels, 0)
        assert ground.translational_dof == 0
        assert ground.rotational_dof == 0

    def test_body_count_matches(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.BALL),
        ]
        labels = label_assembly(bodies, joints, ground_body=0)
        assert len(labels.per_body) == 3
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assembly label
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAssemblyLabel:
    """Assembly-wide summary labels."""

    def test_underconstrained_chain(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
        ]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.classification == "underconstrained"
        assert summary.is_rigid is False

    def test_well_constrained(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.FIXED)]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.classification == "well-constrained"
        assert summary.is_rigid is True
        assert summary.is_minimally_rigid is True

    def test_overconstrained(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
        joints = [
            _make_joint(0, 0, 1, JointType.REVOLUTE),
            _make_joint(1, 1, 2, JointType.REVOLUTE),
            _make_joint(2, 2, 0, JointType.REVOLUTE),
        ]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.redundant_count > 0

    def test_has_degeneracy_with_parallel_axes(self) -> None:
        """Parallel revolute axes in a loop create geometric degeneracy."""
        z_axis = (0.0, 0.0, 1.0)
        bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
        edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
        joints = [
            _make_joint(i, a, b, JointType.REVOLUTE, axis=z_axis)
            for i, (a, b) in enumerate(edges)
        ]
        summary = label_assembly(bodies, joints, ground_body=0).assembly
        assert summary.has_degeneracy is True
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToDict:
    """to_dict produces JSON-serializable output."""

    @staticmethod
    def _labels_dict() -> dict:
        # Shared fixture: one revolute between two grounded bodies.
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
        return label_assembly(bodies, joints, ground_body=0).to_dict()

    def test_top_level_keys(self) -> None:
        assert set(self._labels_dict().keys()) == {
            "per_constraint",
            "per_joint",
            "per_body",
            "assembly",
        }

    def test_per_constraint_keys(self) -> None:
        expected = {
            "joint_id",
            "constraint_idx",
            "pebble_independent",
            "jacobian_independent",
        }
        for entry in self._labels_dict()["per_constraint"]:
            assert set(entry.keys()) == expected

    def test_assembly_keys(self) -> None:
        assert set(self._labels_dict()["assembly"].keys()) == {
            "classification",
            "total_dof",
            "redundant_count",
            "is_rigid",
            "is_minimally_rigid",
            "has_degeneracy",
        }

    def test_json_serializable(self) -> None:
        # json.dumps raises TypeError on anything non-serializable.
        serialized = json.dumps(self._labels_dict())
        assert isinstance(serialized, str)
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelAssemblyEdgeCases:
    """Edge cases for label_assembly."""

    def test_no_joints(self) -> None:
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [], ground_body=0)
        assert not labels.per_constraint
        assert not labels.per_joint
        assert labels.assembly.classification == "underconstrained"
        # With no joints the non-ground body keeps all six freedoms.
        free_body = next(b for b in labels.per_body if b.body_id == 1)
        assert free_body.translational_dof == 3
        assert free_body.rotational_dof == 3

    def test_no_joints_floating(self) -> None:
        labels = label_assembly(_make_bodies((0, 0, 0)), [], ground_body=None)
        assert len(labels.per_body) == 1
        (only_body,) = labels.per_body
        assert only_body.translational_dof == 3
        assert only_body.rotational_dof == 3

    def test_analysis_embedded(self) -> None:
        """AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
        bodies = _make_bodies((0, 0, 0), (2, 0, 0))
        labels = label_assembly(bodies, [_make_joint(0, 0, 1, JointType.REVOLUTE)], ground_body=0)
        for attr in ("combinatorial_classification", "jacobian_rank", "is_rigid"):
            assert hasattr(labels.analysis, attr)
||||
206
tests/datagen/test_pebble_game.py
Normal file
206
tests/datagen/test_pebble_game.py
Normal file
@@ -0,0 +1,206 @@
|
||||
"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import itertools
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.pebble_game import PebbleGame3D
|
||||
from solver.datagen.types import Joint, JointType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint:
    """Shorthand for a revolute joint between bodies *a* and *b*.

    Defaults to a Z axis when none is supplied.
    """
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.REVOLUTE,
        axis=np.array([0.0, 0.0, 1.0]) if axis is None else axis,
    )
|
||||
|
||||
def _fixed(jid: int, a: int, b: int) -> Joint:
    """Shorthand for a fixed joint between bodies *a* and *b*."""
    return Joint(
        joint_id=jid,
        body_a=a,
        body_b=b,
        joint_type=JointType.FIXED,
    )
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAddBody:
    """Body registration basics."""

    def test_single_body_six_pebbles(self) -> None:
        game = PebbleGame3D()
        game.add_body(0)
        # A free 3D body starts with 6 pebbles (6 DOF).
        assert game.state.free_pebbles[0] == 6

    def test_duplicate_body_no_op(self) -> None:
        game = PebbleGame3D()
        for _ in range(2):
            game.add_body(0)
        assert game.state.free_pebbles[0] == 6

    def test_multiple_bodies(self) -> None:
        game = PebbleGame3D()
        for body_id in range(5):
            game.add_body(body_id)
        assert game.get_dof() == 5 * 6
|
||||
|
||||
class TestAddJoint:
    """Joint insertion and DOF accounting."""

    def test_revolute_removes_five_dof(self) -> None:
        game = PebbleGame3D()
        outcomes = game.add_joint(_revolute(0, 0, 1))
        # A revolute contributes 5 scalar constraints.
        assert len(outcomes) == 5
        assert all(outcome["independent"] for outcome in outcomes)
        # 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles.
        assert game.get_dof() == 7

    def test_fixed_removes_six_dof(self) -> None:
        game = PebbleGame3D()
        outcomes = game.add_joint(_fixed(0, 0, 1))
        assert len(outcomes) == 6
        assert all(outcome["independent"] for outcome in outcomes)
        assert game.get_dof() == 6

    def test_ball_removes_three_dof(self) -> None:
        game = PebbleGame3D()
        ball = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL)
        outcomes = game.add_joint(ball)
        assert len(outcomes) == 3
        assert all(outcome["independent"] for outcome in outcomes)
        assert game.get_dof() == 9
|
||||
|
||||
class TestTwoBodiesRevolute:
    """Two bodies connected by a revolute -- demo scenario 1."""

    @staticmethod
    def _game() -> PebbleGame3D:
        game = PebbleGame3D()
        game.add_joint(_revolute(0, 0, 1))
        return game

    def test_internal_dof(self) -> None:
        # Total DOF = 7, internal = 7 - 6 = 1.
        assert self._game().get_internal_dof() == 1

    def test_not_rigid(self) -> None:
        assert not self._game().is_rigid()

    def test_classification(self) -> None:
        assert self._game().classify_assembly() == "underconstrained"
|
||||
|
||||
class TestTwoBodiesFixed:
    """Two bodies + fixed joint -- demo scenario 2."""

    def test_zero_internal_dof(self) -> None:
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        assert game.get_internal_dof() == 0

    def test_rigid(self) -> None:
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        assert game.is_rigid()

    def test_well_constrained(self) -> None:
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        assert game.classify_assembly() == "well-constrained"
class TestTriangleRevolute:
    """Triangle of 3 bodies with revolute joints -- demo scenario 3."""

    @pytest.fixture()
    def pg(self) -> PebbleGame3D:
        # Close the loop 0-1-2-0 with one revolute per edge.
        game = PebbleGame3D()
        for jid, (a, b) in enumerate([(0, 1), (1, 2), (2, 0)]):
            game.add_joint(_revolute(jid, a, b))
        return game

    def test_has_redundant_edges(self, pg: PebbleGame3D) -> None:
        assert pg.get_redundant_count() > 0

    def test_classification_overconstrained(self, pg: PebbleGame3D) -> None:
        # 15 constraints on 3 bodies, while Maxwell needs 6*3-6 = 12.
        assert pg.classify_assembly() in ("overconstrained", "mixed")

    def test_rigid(self, pg: PebbleGame3D) -> None:
        assert pg.is_rigid()
class TestChainNotRigid:
    """A serial chain of 4 bodies with revolute joints is never rigid."""

    def test_chain_underconstrained(self) -> None:
        game = PebbleGame3D()
        for idx in range(3):
            game.add_joint(_revolute(idx, idx, idx + 1))
        assert not game.is_rigid()
        assert game.classify_assembly() == "underconstrained"

    def test_chain_internal_dof(self) -> None:
        game = PebbleGame3D()
        for idx in range(3):
            game.add_joint(_revolute(idx, idx, idx + 1))
        # 4 bodies * 6 = 24 pebbles; 15 independent rows leave 9 free,
        # i.e. 3 internal DOF after subtracting the rigid-body modes.
        assert game.get_internal_dof() == 3
class TestEdgeResults:
    """Result dicts returned by add_joint."""

    def test_result_keys(self) -> None:
        game = PebbleGame3D()
        rows = game.add_joint(_revolute(0, 0, 1))
        wanted = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"}
        assert all(set(row.keys()) == wanted for row in rows)

    def test_edge_ids_sequential(self) -> None:
        game = PebbleGame3D()
        first = game.add_joint(_revolute(0, 0, 1))
        second = game.add_joint(_revolute(1, 1, 2))
        ids = [row["edge_id"] for row in first + second]
        assert ids == list(range(10))

    def test_dof_remaining_monotonic(self) -> None:
        game = PebbleGame3D()
        rows = game.add_joint(_revolute(0, 0, 1))
        dofs = [row["dof_remaining"] for row in rows]
        # Non-increasing: each independent edge consumes one pebble.
        for prev, nxt in zip(dofs, dofs[1:]):
            assert prev >= nxt
class TestGroundedClassification:
    """classify_assembly with grounded=True."""

    def test_grounded_baseline_zero(self) -> None:
        """With grounded=True the baseline is 0 (not 6)."""
        game = PebbleGame3D()
        game.add_joint(_fixed(0, 0, 1))
        # Ungrounded: the 6 remaining pebbles equal the rigid-body baseline.
        assert game.classify_assembly(grounded=False) == "well-constrained"
        # Grounded: the baseline drops to 0, so the 6 leftover pebbles make
        # the raw pebble game (which has no virtual ground body) report
        # underconstrained; the full analysis adds a virtual ground instead.
        assert game.classify_assembly(grounded=True) == "underconstrained"
163
tests/datagen/test_types.py
Normal file
163
tests/datagen/test_types.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""Tests for solver.datagen.types -- shared data types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.datagen.types import (
|
||||
ConstraintAnalysis,
|
||||
Joint,
|
||||
JointType,
|
||||
PebbleState,
|
||||
RigidBody,
|
||||
)
|
||||
|
||||
|
||||
class TestJointType:
    """JointType enum construction and DOF values."""

    # Expected .dof per member; ClassVar so the parametrize decorator below
    # can read it while the class body is still being evaluated.
    EXPECTED_DOF: ClassVar[dict[str, int]] = {
        "FIXED": 6,
        "REVOLUTE": 5,
        "CYLINDRICAL": 4,
        "SLIDER": 5,
        "BALL": 3,
        "PLANAR": 3,
        "SCREW": 5,
        "UNIVERSAL": 4,
        "PARALLEL": 3,
        "PERPENDICULAR": 1,
        "DISTANCE": 1,
    }

    def test_member_count(self) -> None:
        # One enum member per EXPECTED_DOF entry.
        assert len(JointType) == 11

    @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
    def test_dof_values(self, name: str, dof: int) -> None:
        assert JointType[name].dof == dof

    def test_access_by_name(self) -> None:
        # Name lookup returns the identical singleton member.
        assert JointType["REVOLUTE"] is JointType.REVOLUTE

    def test_value_is_tuple(self) -> None:
        # The enum value is a tuple; .dof exposes its DOF component.
        assert JointType.REVOLUTE.value == (1, 5)
        assert JointType.REVOLUTE.dof == 5
class TestRigidBody:
    """RigidBody dataclass defaults and construction."""

    def test_defaults(self) -> None:
        body = RigidBody(body_id=0)
        # A fresh body sits at the origin with identity orientation.
        np.testing.assert_array_equal(body.position, np.zeros(3))
        np.testing.assert_array_equal(body.orientation, np.eye(3))
        assert body.local_anchors == {}

    def test_custom_position(self) -> None:
        where = np.array([1.0, 2.0, 3.0])
        body = RigidBody(body_id=7, position=where)
        np.testing.assert_array_equal(body.position, where)
        assert body.body_id == 7

    def test_local_anchors_mutable(self) -> None:
        body = RigidBody(body_id=0)
        body.local_anchors["top"] = np.array([0.0, 0.0, 1.0])
        assert "top" in body.local_anchors

    def test_default_factory_isolation(self) -> None:
        """Each instance gets its own default containers."""
        first, second = RigidBody(body_id=0), RigidBody(body_id=1)
        first.local_anchors["x"] = np.zeros(3)
        assert "x" not in second.local_anchors
class TestJoint:
    """Joint dataclass defaults and construction."""

    def test_defaults(self) -> None:
        # Anchors default to the origin, the axis to +Z, pitch to 0.
        j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE)
        np.testing.assert_array_equal(j.anchor_a, np.zeros(3))
        np.testing.assert_array_equal(j.anchor_b, np.zeros(3))
        np.testing.assert_array_equal(j.axis, np.array([0.0, 0.0, 1.0]))
        assert j.pitch == 0.0

    def test_full_construction(self) -> None:
        # Every field can be supplied explicitly.
        j = Joint(
            joint_id=5,
            body_a=2,
            body_b=3,
            joint_type=JointType.SCREW,
            anchor_a=np.array([1.0, 0.0, 0.0]),
            anchor_b=np.array([2.0, 0.0, 0.0]),
            axis=np.array([1.0, 0.0, 0.0]),
            pitch=0.5,
        )
        assert j.joint_id == 5
        assert j.joint_type is JointType.SCREW
        assert j.pitch == 0.5
class TestPebbleState:
    """PebbleState dataclass defaults."""

    def test_defaults(self) -> None:
        state = PebbleState()
        # All bookkeeping containers start out empty.
        assert state.free_pebbles == {}
        assert state.directed_edges == {}
        assert state.independent_edges == set()
        assert state.redundant_edges == set()
        assert state.incoming == {}
        assert state.outgoing == {}

    def test_default_factory_isolation(self) -> None:
        first, second = PebbleState(), PebbleState()
        first.free_pebbles[0] = 6
        # Default factories must not share state across instances.
        assert 0 not in second.free_pebbles
class TestConstraintAnalysis:
    """ConstraintAnalysis dataclass construction."""

    def test_construction(self) -> None:
        # A fully rigid, minimally rigid assembly snapshot.
        ca = ConstraintAnalysis(
            combinatorial_dof=6,
            combinatorial_internal_dof=0,
            combinatorial_redundant=0,
            combinatorial_classification="well-constrained",
            per_edge_results=[],
            jacobian_rank=6,
            jacobian_nullity=0,
            jacobian_internal_dof=0,
            numerically_dependent=[],
            geometric_degeneracies=0,
            is_rigid=True,
            is_minimally_rigid=True,
        )
        assert ca.is_rigid is True
        assert ca.is_minimally_rigid is True
        assert ca.combinatorial_classification == "well-constrained"

    def test_per_edge_results_typing(self) -> None:
        """per_edge_results accepts list[dict[str, Any]]."""
        ca = ConstraintAnalysis(
            combinatorial_dof=7,
            combinatorial_internal_dof=1,
            combinatorial_redundant=0,
            combinatorial_classification="underconstrained",
            per_edge_results=[{"edge_id": 0, "independent": True}],
            jacobian_rank=5,
            jacobian_nullity=1,
            jacobian_internal_dof=1,
            numerically_dependent=[],
            geometric_degeneracies=0,
            is_rigid=False,
            is_minimally_rigid=False,
        )
        assert len(ca.per_edge_results) == 1
        assert ca.per_edge_results[0]["edge_id"] == 0
0
tests/mates/__init__.py
Normal file
0
tests/mates/__init__.py
Normal file
287
tests/mates/test_conversion.py
Normal file
287
tests/mates/test_conversion.py
Normal file
@@ -0,0 +1,287 @@
|
||||
"""Tests for solver.mates.conversion -- mate-to-joint conversion."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import JointType, RigidBody
|
||||
from solver.mates.conversion import (
|
||||
MateAnalysisResult,
|
||||
analyze_mate_assembly,
|
||||
convert_mates_to_joints,
|
||||
)
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Build a GeometryRef, defaulting origin to zero and direction to +Z
    for directional geometry (faces, axes, planes)."""
    directional = (GeometryType.FACE, GeometryType.AXIS, GeometryType.PLANE)
    if direction is None and geom_type in directional:
        direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id="Geom001",
        origin=np.zeros(3) if origin is None else origin,
        direction=direction,
    )
def _make_bodies(n: int) -> list[RigidBody]:
    """Create n bodies spaced one unit apart along the x-axis."""
    return [
        RigidBody(body_id=idx, position=np.array([float(idx), 0.0, 0.0]))
        for idx in range(n)
    ]
# ---------------------------------------------------------------------------
|
||||
# convert_mates_to_joints
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConvertMatesToJoints:
    """convert_mates_to_joints function."""

    def test_empty_input(self) -> None:
        # No mates -> no joints and empty forward/backward maps.
        joints, m2j, j2m = convert_mates_to_joints([])
        assert joints == []
        assert m2j == {}
        assert j2m == {}

    def test_hinge_pattern(self) -> None:
        """Concentric + Coincident(plane) -> single REVOLUTE joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        joints, m2j, j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        assert joints[0].joint_type is JointType.REVOLUTE
        assert joints[0].body_a == 0
        assert joints[0].body_b == 1
        # Both mates map to the single joint.
        assert 0 in m2j
        assert 1 in m2j
        assert j2m[joints[0].joint_id] == [0, 1]

    def test_lock_pattern(self) -> None:
        """Lock -> FIXED joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        joints, _m2j, _j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        assert joints[0].joint_type is JointType.FIXED

    def test_unmatched_mate_fallback(self) -> None:
        """A single ANGLE mate with no pattern -> individual joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.ANGLE,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        joints, _m2j, _j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        # The unmatched ANGLE mate falls through to a PERPENDICULAR joint.
        assert joints[0].joint_type is JointType.PERPENDICULAR

    def test_mapping_consistency(self) -> None:
        """mate_to_joint and joint_to_mates are consistent."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
            Mate(
                mate_id=2,
                mate_type=MateType.DISTANCE,
                ref_a=_make_ref(2, GeometryType.POINT),
                ref_b=_make_ref(3, GeometryType.POINT),
            ),
        ]
        joints, m2j, j2m = convert_mates_to_joints(mates)
        # Every mate should be in m2j.
        for mate in mates:
            assert mate.mate_id in m2j
        # Every joint should be in j2m.
        for joint in joints:
            assert joint.joint_id in j2m

    def test_joint_axis_from_geometry(self) -> None:
        """Joint axis should come from mate geometry direction."""
        axis_dir = np.array([1.0, 0.0, 0.0])
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS, direction=axis_dir),
                ref_b=_make_ref(1, GeometryType.AXIS, direction=axis_dir),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        joints, _, _ = convert_mates_to_joints(mates)
        np.testing.assert_array_almost_equal(joints[0].axis, axis_dir)
||||
# ---------------------------------------------------------------------------
|
||||
# MateAnalysisResult
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateAnalysisResult:
    """MateAnalysisResult dataclass."""

    def test_to_dict(self) -> None:
        serialized = MateAnalysisResult(patterns=[], joints=[]).to_dict()
        assert serialized["patterns"] == []
        assert serialized["joints"] == []
        # labels defaults to None when not supplied.
        assert serialized["labels"] is None
# ---------------------------------------------------------------------------
|
||||
# analyze_mate_assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAnalyzeMateAssembly:
    """Full pipeline: mates -> joints -> analysis."""

    def test_two_bodies_hinge(self) -> None:
        """Two bodies connected by hinge mates -> underconstrained (1 DOF)."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert result.analysis is not None
        assert result.labels is not None
        # A revolute joint removes 5 DOF, leaving 1 internal DOF.
        assert result.analysis.combinatorial_internal_dof == 1
        assert len(result.joints) == 1
        assert result.joints[0].joint_type is JointType.REVOLUTE

    def test_two_bodies_fixed(self) -> None:
        """Two bodies with lock mate -> well-constrained."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert result.analysis is not None
        assert result.analysis.combinatorial_internal_dof == 0
        assert result.analysis.is_rigid

    def test_grounded_assembly(self) -> None:
        """Grounded assembly analysis works."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        # ground_body pins body 0 as the reference frame.
        result = analyze_mate_assembly(bodies, mates, ground_body=0)
        assert result.analysis is not None
        assert result.analysis.is_rigid

    def test_no_mates(self) -> None:
        """Assembly with no mates should be fully underconstrained."""
        bodies = _make_bodies(2)
        result = analyze_mate_assembly(bodies, [])
        assert result.analysis is not None
        # The free second body retains all 6 internal DOF.
        assert result.analysis.combinatorial_internal_dof == 6
        assert len(result.joints) == 0

    def test_single_body(self) -> None:
        """Single body, no mates."""
        bodies = _make_bodies(1)
        result = analyze_mate_assembly(bodies, [])
        assert result.analysis is not None
        assert len(result.joints) == 0

    def test_result_traceability(self) -> None:
        """mate_to_joint and joint_to_mates populated in result."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert 0 in result.mate_to_joint
        assert 1 in result.mate_to_joint
        assert len(result.joint_to_mates) > 0
155
tests/mates/test_generator.py
Normal file
155
tests/mates/test_generator.py
Normal file
@@ -0,0 +1,155 @@
|
||||
"""Tests for solver.mates.generator -- synthetic mate generator."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch
|
||||
from solver.mates.primitives import MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# SyntheticMateGenerator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSyntheticMateGenerator:
    """SyntheticMateGenerator core functionality."""

    def test_generate_basic(self) -> None:
        """Generate a simple assembly with mates."""
        bodies, mates, result = SyntheticMateGenerator(seed=42).generate(3)
        assert len(bodies) == 3
        assert len(mates) > 0
        assert result.analysis is not None

    def test_deterministic_with_seed(self) -> None:
        """Same seed produces same output."""
        _, first, _ = SyntheticMateGenerator(seed=123).generate(3)
        _, second, _ = SyntheticMateGenerator(seed=123).generate(3)
        assert len(first) == len(second)
        for a, b in zip(first, second, strict=True):
            assert a.mate_type == b.mate_type
            assert a.ref_a.body_id == b.ref_a.body_id

    def test_grounded(self) -> None:
        """Grounded assembly should work."""
        bodies, _mates, result = SyntheticMateGenerator(seed=42).generate(3, grounded=True)
        assert len(bodies) == 3
        assert result.analysis is not None

    def test_revolute_produces_two_mates(self) -> None:
        """A revolute joint should reverse-map to 2 mates."""
        # 2 bodies -> 1 revolute joint -> concentric + coincident mates.
        _bodies, mates, _result = SyntheticMateGenerator(seed=42).generate(2)
        assert len(mates) == 2
        kinds = {m.mate_type for m in mates}
        assert MateType.CONCENTRIC in kinds
        assert MateType.COINCIDENT in kinds
class TestReverseMapping:
    """Reverse mapping from joints to mates."""

    def test_revolute_mapping(self) -> None:
        """REVOLUTE -> Concentric + Coincident."""
        _bodies, mates, _result = SyntheticMateGenerator(seed=42).generate(2)
        kinds = [m.mate_type for m in mates]
        assert MateType.CONCENTRIC in kinds
        assert MateType.COINCIDENT in kinds

    def test_round_trip_analysis(self) -> None:
        """Generated mates round-trip through analysis successfully."""
        _bodies, _mates, result = SyntheticMateGenerator(seed=42).generate(4)
        assert result.analysis is not None
        assert result.labels is not None
        # The generated mates must reconstruct at least one joint.
        assert len(result.joints) > 0
class TestNoiseInjection:
    """Noise injection mechanisms."""

    def test_redundant_injection(self) -> None:
        """Redundant prob > 0 produces more mates than clean version."""
        _, baseline, _ = SyntheticMateGenerator(seed=42, redundant_prob=0.0).generate(4)
        _, noisy, _ = SyntheticMateGenerator(seed=42, redundant_prob=1.0).generate(4)
        assert len(noisy) > len(baseline)

    def test_missing_injection(self) -> None:
        """Missing prob > 0 produces fewer mates than clean version."""
        _, baseline, _ = SyntheticMateGenerator(seed=42, missing_prob=0.0).generate(4)
        _, noisy, _ = SyntheticMateGenerator(seed=42, missing_prob=0.5).generate(4)
        # Dropping mates can only shrink (or, rarely, equal) the clean count.
        assert len(noisy) <= len(baseline)

    def test_incompatible_injection(self) -> None:
        """Incompatible prob > 0 adds mates with wrong geometry."""
        _, noisy, _ = SyntheticMateGenerator(seed=42, incompatible_prob=1.0).generate(3)
        _, baseline, _ = SyntheticMateGenerator(seed=42).generate(3)
        # Extra incompatible mates appear beyond the clean count.
        assert len(noisy) > len(baseline)
# ---------------------------------------------------------------------------
|
||||
# generate_mate_training_batch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGenerateMateTrainingBatch:
    """Batch generation function."""

    def test_batch_structure(self) -> None:
        """Each example has required keys."""
        batch = generate_mate_training_batch(batch_size=3, seed=42)
        assert len(batch) == 3
        required = ("bodies", "mates", "patterns", "labels", "n_bodies", "n_mates", "n_joints")
        for example in batch:
            for key in required:
                assert key in example

    def test_batch_deterministic(self) -> None:
        """Same seed produces same batch."""
        first = generate_mate_training_batch(batch_size=5, seed=99)
        second = generate_mate_training_batch(batch_size=5, seed=99)
        for a, b in zip(first, second, strict=True):
            assert a["n_bodies"] == b["n_bodies"]
            assert a["n_mates"] == b["n_mates"]

    def test_batch_grounded_ratio(self) -> None:
        """Batch respects grounded_ratio parameter."""
        # An all-grounded batch still yields the requested example count.
        batch = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0)
        assert len(batch) == 5

    def test_batch_with_noise(self) -> None:
        """Batch with noise injection runs without error."""
        batch = generate_mate_training_batch(
            batch_size=3,
            seed=42,
            redundant_prob=0.3,
            missing_prob=0.1,
        )
        assert len(batch) == 3
        for example in batch:
            assert example["n_mates"] >= 0
224
tests/mates/test_labeling.py
Normal file
224
tests/mates/test_labeling.py
Normal file
@@ -0,0 +1,224 @@
|
||||
"""Tests for solver.mates.labeling -- mate-level ground truth labels."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import RigidBody
|
||||
from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly
|
||||
from solver.mates.patterns import JointPattern
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Build a GeometryRef, defaulting origin to zero and direction to +Z
    for directional geometry (faces, axes, planes)."""
    directional = (GeometryType.FACE, GeometryType.AXIS, GeometryType.PLANE)
    if direction is None and geom_type in directional:
        direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id="Geom001",
        origin=np.zeros(3) if origin is None else origin,
        direction=direction,
    )
def _make_bodies(n: int) -> list[RigidBody]:
    """Create n bodies spaced one unit apart along the x-axis."""
    return [
        RigidBody(body_id=idx, position=np.array([float(idx), 0.0, 0.0]))
        for idx in range(n)
    ]
# ---------------------------------------------------------------------------
|
||||
# MateLabel
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateLabel:
    """MateLabel dataclass."""

    def test_defaults(self) -> None:
        # A bare label starts independent with no pattern or issue attached.
        ml = MateLabel(mate_id=0)
        assert ml.is_independent is True
        assert ml.is_redundant is False
        assert ml.is_degenerate is False
        assert ml.pattern is None
        assert ml.issue is None

    def test_to_dict(self) -> None:
        ml = MateLabel(
            mate_id=5,
            is_independent=False,
            is_redundant=True,
            pattern=JointPattern.HINGE,
            issue="redundant",
        )
        d = ml.to_dict()
        assert d["mate_id"] == 5
        assert d["is_redundant"] is True
        # The pattern enum serializes to its string value.
        assert d["pattern"] == "hinge"
        assert d["issue"] == "redundant"

    def test_to_dict_none_pattern(self) -> None:
        # A missing pattern stays None after serialization.
        ml = MateLabel(mate_id=0)
        d = ml.to_dict()
        assert d["pattern"] is None
# ---------------------------------------------------------------------------
|
||||
# MateAssemblyLabels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateAssemblyLabels:
    """MateAssemblyLabels dataclass."""

    def test_to_dict_structure(self) -> None:
        """to_dict produces expected keys."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        d = result.to_dict()
        # Serialized form carries per-mate labels, pattern info, and the
        # assembly-level summary.
        assert "per_mate" in d
        assert "patterns" in d
        assert "assembly" in d
        assert isinstance(d["per_mate"], list)
# ---------------------------------------------------------------------------
|
||||
# label_mate_assembly
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelMateAssembly:
    """Full labeling pipeline."""

    def test_clean_assembly_no_redundancy(self) -> None:
        """Two bodies with lock mate -> clean, no redundancy."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert isinstance(result, MateAssemblyLabels)
        assert len(result.per_mate) == 1
        ml = result.per_mate[0]
        assert ml.mate_id == 0
        assert ml.is_independent is True
        assert ml.is_redundant is False
        assert ml.issue is None

    def test_redundant_assembly(self) -> None:
        """Two lock mates on same body pair -> one is redundant."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
                ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert len(result.per_mate) == 2
        redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant)
        # At least one of the duplicate locks must be flagged redundant.
        assert redundant_count >= 1
        assert result.assembly.redundant_count > 0

    def test_hinge_pattern_labeling(self) -> None:
        """Hinge mates get pattern membership."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert len(result.per_mate) == 2
        # Both mates should be part of the hinge pattern.
        for ml in result.per_mate:
            assert ml.pattern is JointPattern.HINGE
            assert ml.is_independent is True

    def test_grounded_assembly(self) -> None:
        """Grounded assembly labeling works."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates, ground_body=0)
        assert result.assembly.is_rigid

    def test_empty_mates(self) -> None:
        """No mates -> no per_mate labels, underconstrained."""
        bodies = _make_bodies(2)
        result = label_mate_assembly(bodies, [])
        assert len(result.per_mate) == 0
        assert result.assembly.classification == "underconstrained"

    def test_assembly_classification(self) -> None:
        """Assembly classification is present."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        # Classification is always one of the four known buckets.
        assert result.assembly.classification in {
            "well-constrained",
            "overconstrained",
            "underconstrained",
            "mixed",
        }
285
tests/mates/test_patterns.py
Normal file
285
tests/mates/test_patterns.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""Tests for solver.mates.patterns -- joint pattern recognition."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.types import JointType
|
||||
from solver.mates.patterns import JointPattern, PatternMatch, recognize_patterns
|
||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    geometry_id: str = "Geom001",
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Build a GeometryRef, defaulting origin to zero and direction to +Z.

    A default direction is only supplied for geometry kinds that carry an
    orientation (faces, axes, planes); other kinds keep ``direction=None``.
    """
    directional_kinds = (GeometryType.FACE, GeometryType.AXIS, GeometryType.PLANE)
    resolved_origin = np.zeros(3) if origin is None else origin
    resolved_direction = direction
    if resolved_direction is None and geom_type in directional_kinds:
        resolved_direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id=geometry_id,
        origin=resolved_origin,
        direction=resolved_direction,
    )
|
||||
|
||||
|
||||
def _make_mate(
    mate_id: int,
    mate_type: MateType,
    body_a: int,
    body_b: int,
    geom_a: GeometryType = GeometryType.FACE,
    geom_b: GeometryType = GeometryType.FACE,
) -> Mate:
    """Build a Mate of *mate_type* joining *body_a* and *body_b*.

    Geometry kinds default to FACE on both sides; refs come from _make_ref.
    """
    ref_pair = (_make_ref(body_a, geom_a), _make_ref(body_b, geom_b))
    return Mate(
        mate_id=mate_id,
        mate_type=mate_type,
        ref_a=ref_pair[0],
        ref_b=ref_pair[1],
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# JointPattern enum
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJointPattern:
    """Sanity checks on the JointPattern enum."""

    def test_member_count(self) -> None:
        # The pattern vocabulary is fixed at nine entries.
        assert len(JointPattern) == 9

    def test_string_values(self) -> None:
        # Every member carries a string payload.
        values = [member.value for member in JointPattern]
        assert all(isinstance(v, str) for v in values)

    def test_access_by_name(self) -> None:
        assert JointPattern["HINGE"] is JointPattern.HINGE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# PatternMatch
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestPatternMatch:
    """PatternMatch dataclass construction and serialization."""

    @staticmethod
    def _fixed_match(mate_id: int) -> PatternMatch:
        """A full-confidence FIXED match over bodies 0 and 1."""
        return PatternMatch(
            pattern=JointPattern.FIXED,
            mates=[_make_mate(mate_id, MateType.LOCK, 0, 1)],
            body_a=0,
            body_b=1,
            confidence=1.0,
            equivalent_joint_type=JointType.FIXED,
        )

    def test_construction(self) -> None:
        match = self._fixed_match(0)
        assert match.pattern is JointPattern.FIXED
        assert match.confidence == 1.0
        # missing_mates defaults to an empty list.
        assert match.missing_mates == []

    def test_to_dict(self) -> None:
        payload = self._fixed_match(5).to_dict()
        assert payload["pattern"] == "fixed"
        assert payload["mate_ids"] == [5]
        assert payload["equivalent_joint_type"] == "FIXED"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recognize_patterns — canonical patterns
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecognizeCanonical:
    """Canonical mate combinations recognized at full confidence."""

    def test_empty_input(self) -> None:
        # No mates means no matches at all.
        assert recognize_patterns([]) == []

    def test_hinge(self) -> None:
        """Concentric(axis) + Coincident(plane) -> Hinge."""
        combo = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
            _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
        ]
        best = recognize_patterns(combo)[0]
        assert best.pattern is JointPattern.HINGE
        assert best.confidence == 1.0
        assert best.equivalent_joint_type is JointType.REVOLUTE
        assert best.missing_mates == []

    def test_slider(self) -> None:
        """Coincident(plane) + Parallel(axis) -> Slider."""
        combo = [
            _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
            _make_mate(1, MateType.PARALLEL, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
        ]
        best = recognize_patterns(combo)[0]
        assert best.pattern is JointPattern.SLIDER
        assert best.confidence == 1.0
        assert best.equivalent_joint_type is JointType.SLIDER

    def test_cylinder(self) -> None:
        """A lone Concentric(axis) mate matches Cylinder at full confidence."""
        matches = recognize_patterns(
            [_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS)]
        )
        cylinders = [m for m in matches if m.pattern is JointPattern.CYLINDER]
        assert len(cylinders) >= 1
        assert cylinders[0].confidence == 1.0
        assert cylinders[0].equivalent_joint_type is JointType.CYLINDRICAL

    def test_ball(self) -> None:
        """Coincident(point) -> Ball."""
        best = recognize_patterns(
            [_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.POINT, GeometryType.POINT)]
        )[0]
        assert best.pattern is JointPattern.BALL
        assert best.confidence == 1.0
        assert best.equivalent_joint_type is JointType.BALL

    def test_planar_face(self) -> None:
        """Coincident(face) -> Planar."""
        best = recognize_patterns(
            [_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.FACE, GeometryType.FACE)]
        )[0]
        assert best.pattern is JointPattern.PLANAR
        assert best.confidence == 1.0
        assert best.equivalent_joint_type is JointType.PLANAR

    def test_fixed(self) -> None:
        """Lock -> Fixed."""
        best = recognize_patterns(
            [_make_mate(0, MateType.LOCK, 0, 1, GeometryType.FACE, GeometryType.FACE)]
        )[0]
        assert best.pattern is JointPattern.FIXED
        assert best.confidence == 1.0
        assert best.equivalent_joint_type is JointType.FIXED
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recognize_patterns — partial matches
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecognizePartial:
    """Half-confidence matches that hint at the missing mate."""

    def test_concentric_without_plane_hints_hinge(self) -> None:
        """Concentric alone matches hinge at 0.5 confidence with missing hint."""
        matches = recognize_patterns(
            [_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS)]
        )
        hinges = [m for m in matches if m.pattern is JointPattern.HINGE]
        assert len(hinges) >= 1
        # The partial match should say which mate would complete the hinge.
        assert hinges[0].confidence == 0.5
        assert len(hinges[0].missing_mates) > 0

    def test_coincident_plane_without_parallel_hints_slider(self) -> None:
        """Coincident(plane) alone matches slider at 0.5 confidence."""
        matches = recognize_patterns(
            [_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE)]
        )
        sliders = [m for m in matches if m.pattern is JointPattern.SLIDER]
        assert len(sliders) >= 1
        assert sliders[0].confidence == 0.5
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# recognize_patterns — ambiguous / multi-body
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecognizeAmbiguous:
    """Ambiguous inputs and assemblies spanning several body pairs."""

    def test_concentric_matches_both_hinge_and_cylinder(self) -> None:
        """One concentric mate yields both a partial hinge and a cylinder match."""
        matches = recognize_patterns(
            [_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS)]
        )
        seen = {m.pattern for m in matches}
        assert JointPattern.HINGE in seen
        assert JointPattern.CYLINDER in seen

    def test_multiple_body_pairs(self) -> None:
        """Mates across different body pairs produce separate pattern matches."""
        matches = recognize_patterns(
            [
                _make_mate(0, MateType.LOCK, 0, 1),
                _make_mate(1, MateType.COINCIDENT, 2, 3, GeometryType.POINT, GeometryType.POINT),
            ]
        )
        seen_pairs = {(m.body_a, m.body_b) for m in matches}
        assert (0, 1) in seen_pairs
        assert (2, 3) in seen_pairs

    def test_results_sorted_by_confidence(self) -> None:
        """Results come back ordered from highest to lowest confidence."""
        matches = recognize_patterns(
            [
                _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
                _make_mate(1, MateType.LOCK, 2, 3),
            ]
        )
        scores = [m.confidence for m in matches]
        assert scores == sorted(scores, reverse=True)

    def test_unknown_pattern(self) -> None:
        """A mate type that matches no rule returns UNKNOWN."""
        matches = recognize_patterns(
            [_make_mate(0, MateType.ANGLE, 0, 1, GeometryType.FACE, GeometryType.FACE)]
        )
        assert any(m.pattern is JointPattern.UNKNOWN for m in matches)

    def test_body_pair_normalization(self) -> None:
        """Mates with reversed body order should be grouped together."""
        # Same pair (0, 1) written in both orders; the recognizer must merge them.
        matches = recognize_patterns(
            [
                _make_mate(0, MateType.CONCENTRIC, 1, 0, GeometryType.AXIS, GeometryType.AXIS),
                _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
            ]
        )
        hinges = [m for m in matches if m.pattern is JointPattern.HINGE]
        assert len(hinges) >= 1
        assert hinges[0].confidence == 1.0
|
||||
329
tests/mates/test_primitives.py
Normal file
329
tests/mates/test_primitives.py
Normal file
@@ -0,0 +1,329 @@
|
||||
"""Tests for solver.mates.primitives -- mate type definitions."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import ClassVar
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from solver.mates.primitives import (
|
||||
GeometryRef,
|
||||
GeometryType,
|
||||
Mate,
|
||||
MateType,
|
||||
dof_removed,
|
||||
)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    geometry_id: str = "Geom001",
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Create a GeometryRef with a zero origin and, for oriented kinds, +Z.

    Only FACE/AXIS/PLANE get a default direction; other kinds keep None.
    """
    needs_direction = geom_type in {
        GeometryType.FACE,
        GeometryType.AXIS,
        GeometryType.PLANE,
    }
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id=geometry_id,
        origin=np.zeros(3) if origin is None else origin,
        direction=(
            np.array([0.0, 0.0, 1.0])
            if direction is None and needs_direction
            else direction
        ),
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# MateType
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateType:
    """MateType enum construction and DOF values."""

    # Ground truth: mate type name -> default DOF removed.
    EXPECTED_DOF: ClassVar[dict[str, int]] = {
        "COINCIDENT": 3,
        "CONCENTRIC": 2,
        "PARALLEL": 2,
        "PERPENDICULAR": 1,
        "TANGENT": 1,
        "DISTANCE": 1,
        "ANGLE": 1,
        "LOCK": 6,
    }

    def test_member_count(self) -> None:
        assert len(MateType) == 8

    @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
    def test_default_dof_values(self, name: str, dof: int) -> None:
        assert MateType[name].default_dof == dof

    def test_value_is_tuple(self) -> None:
        # Each member's value is the pair (ordinal, default_dof).
        coincident = MateType.COINCIDENT
        assert coincident.value == (0, 3)
        assert coincident.default_dof == 3

    def test_access_by_name(self) -> None:
        assert MateType["LOCK"] is MateType.LOCK

    def test_no_alias_collision(self) -> None:
        # Distinct ordinals mean no two members collapsed into enum aliases.
        distinct_ordinals = {member.value[0] for member in MateType}
        assert len(distinct_ordinals) == len(MateType)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GeometryType
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGeometryType:
    """GeometryType enum."""

    def test_member_count(self) -> None:
        assert len(GeometryType) == 5

    def test_string_values(self) -> None:
        # Values are string payloads equal to the lowercased member name.
        name_to_value = {member.name: member.value for member in GeometryType}
        assert all(isinstance(v, str) for v in name_to_value.values())
        assert all(name.lower() == value for name, value in name_to_value.items())

    def test_access_by_name(self) -> None:
        assert GeometryType["FACE"] is GeometryType.FACE
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# GeometryRef
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGeometryRef:
    """GeometryRef dataclass: construction, defaults, serialization."""

    def test_construction(self) -> None:
        axis_ref = _make_ref(0, GeometryType.AXIS, geometry_id="Axis001")
        assert axis_ref.body_id == 0
        assert axis_ref.geometry_type is GeometryType.AXIS
        assert axis_ref.geometry_id == "Axis001"
        np.testing.assert_array_equal(axis_ref.origin, np.zeros(3))
        # The helper supplies a default direction for axes.
        assert axis_ref.direction is not None

    def test_default_direction_none(self) -> None:
        # A hand-built point ref carries no orientation.
        point_ref = GeometryRef(
            body_id=0,
            geometry_type=GeometryType.POINT,
            geometry_id="Point001",
        )
        assert point_ref.direction is None

    def test_to_dict_round_trip(self) -> None:
        original = _make_ref(
            1,
            GeometryType.FACE,
            origin=np.array([1.0, 2.0, 3.0]),
            direction=np.array([0.0, 1.0, 0.0]),
        )
        restored = GeometryRef.from_dict(original.to_dict())
        assert restored.body_id == original.body_id
        assert restored.geometry_type is original.geometry_type
        assert restored.geometry_id == original.geometry_id
        np.testing.assert_array_almost_equal(restored.origin, original.origin)
        assert restored.direction is not None
        np.testing.assert_array_almost_equal(restored.direction, original.direction)

    def test_to_dict_with_none_direction(self) -> None:
        # None direction must survive the dict round-trip.
        point_ref = GeometryRef(
            body_id=2,
            geometry_type=GeometryType.POINT,
            geometry_id="Point002",
            origin=np.array([5.0, 6.0, 7.0]),
        )
        payload = point_ref.to_dict()
        assert payload["direction"] is None
        assert GeometryRef.from_dict(payload).direction is None
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMate:
    """Mate dataclass: defaults and dict round-trips."""

    @staticmethod
    def _coincident_face_mate() -> Mate:
        """A minimal COINCIDENT face-face mate between bodies 0 and 1."""
        return Mate(
            mate_id=0,
            mate_type=MateType.COINCIDENT,
            ref_a=_make_ref(0, GeometryType.FACE),
            ref_b=_make_ref(1, GeometryType.FACE),
        )

    def test_construction(self) -> None:
        mate = self._coincident_face_mate()
        assert mate.mate_id == 0
        assert mate.mate_type is MateType.COINCIDENT

    def test_value_default_zero(self) -> None:
        assert self._coincident_face_mate().value == 0.0

    def test_tolerance_default(self) -> None:
        assert self._coincident_face_mate().tolerance == 1e-6

    def test_to_dict_round_trip(self) -> None:
        original = Mate(
            mate_id=5,
            mate_type=MateType.CONCENTRIC,
            ref_a=_make_ref(0, GeometryType.AXIS, origin=np.array([1.0, 0.0, 0.0])),
            ref_b=_make_ref(1, GeometryType.AXIS, origin=np.array([2.0, 0.0, 0.0])),
            value=0.0,
            tolerance=1e-8,
        )
        restored = Mate.from_dict(original.to_dict())
        assert restored.mate_id == original.mate_id
        assert restored.mate_type is original.mate_type
        assert restored.ref_a.body_id == original.ref_a.body_id
        assert restored.ref_b.body_id == original.ref_b.body_id
        assert restored.value == original.value
        assert restored.tolerance == original.tolerance

    def test_from_dict_missing_optional(self) -> None:
        # value and tolerance are optional in the serialized form.
        payload = {
            "mate_id": 1,
            "mate_type": "DISTANCE",
            "ref_a": _make_ref(0, GeometryType.POINT).to_dict(),
            "ref_b": _make_ref(1, GeometryType.POINT).to_dict(),
        }
        restored = Mate.from_dict(payload)
        assert restored.value == 0.0
        assert restored.tolerance == 1e-6
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# dof_removed
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDofRemoved:
    """Context-dependent DOF removal counts."""

    @staticmethod
    def _dof(mate_type: MateType, geom_a: GeometryType, geom_b: GeometryType) -> int:
        """DOF removed for *mate_type* over refs of the given geometry kinds."""
        return dof_removed(mate_type, _make_ref(0, geom_a), _make_ref(1, geom_b))

    def test_coincident_face_face(self) -> None:
        assert self._dof(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE) == 3

    def test_coincident_point_point(self) -> None:
        assert self._dof(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT) == 3

    def test_coincident_edge_edge(self) -> None:
        assert self._dof(MateType.COINCIDENT, GeometryType.EDGE, GeometryType.EDGE) == 2

    def test_coincident_face_point(self) -> None:
        assert self._dof(MateType.COINCIDENT, GeometryType.FACE, GeometryType.POINT) == 1

    def test_concentric_axis_axis(self) -> None:
        assert self._dof(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS) == 2

    def test_lock_any(self) -> None:
        # LOCK removes all six DOF regardless of geometry kinds.
        assert self._dof(MateType.LOCK, GeometryType.FACE, GeometryType.POINT) == 6

    def test_distance_any(self) -> None:
        assert self._dof(MateType.DISTANCE, GeometryType.POINT, GeometryType.EDGE) == 1

    def test_unknown_combo_uses_default(self) -> None:
        """Unlisted geometry combos fall back to default_dof."""
        result = self._dof(MateType.COINCIDENT, GeometryType.EDGE, GeometryType.POINT)
        assert result == MateType.COINCIDENT.default_dof
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Mate.validate
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestMateValidation:
    """Mate.validate() compatibility checks."""

    @staticmethod
    def _mate(mate_type: MateType, ref_a: GeometryRef, ref_b: GeometryRef) -> Mate:
        """Wrap two refs into a Mate with a fixed id of 0."""
        return Mate(mate_id=0, mate_type=mate_type, ref_a=ref_a, ref_b=ref_b)

    def test_valid_concentric(self) -> None:
        mate = self._mate(
            MateType.CONCENTRIC,
            _make_ref(0, GeometryType.AXIS),
            _make_ref(1, GeometryType.AXIS),
        )
        mate.validate()  # should not raise

    def test_invalid_concentric_face(self) -> None:
        # CONCENTRIC does not accept a FACE reference.
        mate = self._mate(
            MateType.CONCENTRIC,
            _make_ref(0, GeometryType.FACE),
            _make_ref(1, GeometryType.AXIS),
        )
        with pytest.raises(ValueError, match="CONCENTRIC"):
            mate.validate()

    def test_valid_coincident_face_face(self) -> None:
        mate = self._mate(
            MateType.COINCIDENT,
            _make_ref(0, GeometryType.FACE),
            _make_ref(1, GeometryType.FACE),
        )
        mate.validate()  # should not raise

    def test_invalid_self_mate(self) -> None:
        # Distinct geometry on the SAME body is still a self-mate.
        mate = self._mate(
            MateType.COINCIDENT,
            _make_ref(0, GeometryType.FACE),
            _make_ref(0, GeometryType.FACE, geometry_id="Face002"),
        )
        with pytest.raises(ValueError, match="Self-mate"):
            mate.validate()

    def test_invalid_parallel_point(self) -> None:
        mate = self._mate(
            MateType.PARALLEL,
            _make_ref(0, GeometryType.POINT),
            _make_ref(1, GeometryType.AXIS),
        )
        with pytest.raises(ValueError, match="PARALLEL"):
            mate.validate()

    def test_invalid_tangent_axis(self) -> None:
        mate = self._mate(
            MateType.TANGENT,
            _make_ref(0, GeometryType.AXIS),
            _make_ref(1, GeometryType.FACE),
        )
        with pytest.raises(ValueError, match="TANGENT"):
            mate.validate()

    def test_missing_direction_for_axis(self) -> None:
        # A hand-built axis ref without a direction must be rejected.
        bare_axis = GeometryRef(
            body_id=0,
            geometry_type=GeometryType.AXIS,
            geometry_id="Axis001",
            origin=np.zeros(3),
            direction=None,  # missing!
        )
        mate = self._mate(MateType.CONCENTRIC, bare_axis, _make_ref(1, GeometryType.AXIS))
        with pytest.raises(ValueError, match="direction"):
            mate.validate()
|
||||
Reference in New Issue
Block a user