26 Commits

Author SHA1 Message Date
forbes
f68245e952 ci: install torch separately to avoid --index-url replacing PyPI
Some checks failed
CI / lint (push) Successful in 24s
CI / type-check (push) Successful in 3m27s
CI / test (push) Failing after 5m33s
CI / datagen (push) Has been skipped
2026-02-03 18:35:17 -06:00
forbes
b088b74dcf ci: remove internal CA dependency for DMZ-compatible public branch
Some checks failed
CI / lint (push) Successful in 25s
CI / type-check (push) Failing after 19s
CI / test (push) Failing after 6m8s
CI / datagen (push) Has been skipped
Strip ipa.kindred.internal CA trust steps. Replace Node-dependent
actions with raw git commands for act runner compatibility.
2026-02-03 18:10:04 -06:00
forbes
c728bd93f7 Merge remote-tracking branch 'public/main'
Some checks failed
CI / datagen (push) Blocked by required conditions
CI / lint (push) Failing after 2m20s
CI / test (push) Has been cancelled
CI / type-check (push) Has been cancelled
2026-02-03 18:03:54 -06:00
forbes
bbbc5e0137 ci: use venv for PEP 668 compatibility on runner 2026-02-03 17:59:05 -06:00
forbes
40cda51142 ci: install internal CA from IPA instead of skipping SSL verification
Fetches the Kindred CA cert from ipa.kindred.internal and installs it
into the system trust store before checkout. Removes GIT_SSL_NO_VERIFY.
2026-02-03 17:57:53 -06:00
forbes
e45207b7cc ci: skip SSL verification for internal Gitea runner 2026-02-03 17:56:13 -06:00
forbes
537d8c7689 ci: add datagen job, adapt workflow for Gitea runner
- Drop actions/setup-python, use system python3
- Use full Gitea-compatible action URLs
- CPU-only torch via pytorch whl/cpu index
- Add datagen job with cache/checkpoint resume and artifact upload
- Manual dispatch with configurable assembly count and worker count
- Datagen runs on push to main (after tests pass) or manual trigger
2026-02-03 17:52:48 -06:00
93bda28f67 feat(mates): add mate-level ground truth labels
Some checks failed
CI / lint (push) Successful in 1m45s
CI / type-check (push) Successful in 2m32s
CI / test (push) Failing after 3m36s
MateLabel and MateAssemblyLabels dataclasses with label_mate_assembly()
that back-attributes joint-level independence to originating mates.
Detects redundant and degenerate mates with pattern membership tracking.

Closes #15
2026-02-03 13:08:23 -06:00
239e45c7f9 feat(mates): add mate-based synthetic assembly generator
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
SyntheticMateGenerator wraps existing joint generator with reverse
mapping (joint->mates) and configurable noise injection (redundant,
missing, incompatible mates). Batch generation via
generate_mate_training_batch().

Closes #14
2026-02-03 13:05:58 -06:00
118474f892 feat(mates): add mate-to-joint conversion and assembly analysis
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
convert_mates_to_joints() bridges mate-level constraints to the existing
joint-based analysis pipeline. analyze_mate_assembly() orchestrates the
full pipeline with bidirectional mate-joint traceability.

Closes #13
2026-02-03 13:03:13 -06:00
e8143cf64c feat(mates): add joint pattern recognition
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
JointPattern enum (9 patterns), PatternMatch dataclass, and
recognize_patterns() function with data-driven pattern rules.
Supports canonical, partial, and ambiguous pattern matching.

Closes #12
2026-02-03 12:59:53 -06:00
9f53fdb154 feat(mates): add mate type definitions and geometry references
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
MateType enum (8 types), GeometryType enum (5 types), GeometryRef and
Mate dataclasses with validation, serialization, and context-dependent
DOF removal via dof_removed().

Closes #11
2026-02-03 12:55:37 -06:00
5d1988b513 Merge remote-tracking branch 'public/main'
Some checks failed
CI / lint (push) Successful in 38s
CI / type-check (push) Successful in 1m47s
CI / test (push) Failing after 3m2s
# Conflicts:
#	.gitignore
#	README.md
2026-02-03 10:53:48 -06:00
f29060491e feat(datagen): add dataset generation CLI with sharding and checkpointing
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator,
  ShardSpec/ShardResult dataclasses, parallel shard generation via
  ProcessPoolExecutor, checkpoint/resume support, index and stats output
- Add scripts/generate_synthetic.py CLI entry point with Hydra-first
  and argparse fallback modes
- Add minimal YAML parser (parse_simple_yaml) for config loading
  without PyYAML dependency
- Add progress display with tqdm fallback to print-based ETA
- Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every
- Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator
  exports
- Add tests/datagen/test_dataset.py with 28 tests covering config,
  YAML parsing, seed derivation, end-to-end generation, resume,
  stats/index structure, determinism, and CLI integration

Closes #10
2026-02-03 08:44:31 -06:00
8a49f8ef40 feat: ground truth labeling pipeline
Some checks failed
CI / lint (push) Failing after 25m6s
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
- Create solver/datagen/labeling.py with label_assembly() function
- Add dataclasses: ConstraintLabel, JointLabel, BodyDofLabel,
  AssemblyLabel, AssemblyLabels
- Per-constraint labels: pebble_independent + jacobian_independent
- Per-joint labels: aggregated independent/redundant/total counts
- Per-body DOF: translational + rotational from nullspace projection
- Assembly label: classification, total_dof, has_degeneracy flag
- AssemblyLabels.to_dict() for JSON-serializable output
- Integrate into generate_training_batch (adds 'labels' field)
- Export AssemblyLabels and label_assembly from datagen package
- Add 25 labeling tests + 1 batch structure test (184 total)

Closes #9
2026-02-02 15:20:02 -06:00
78289494e2 feat: geometric diversity for synthetic assembly generation
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
- Add AxisStrategy type (cardinal, random, near_parallel)
- Add random body orientations via scipy.spatial.transform.Rotation
- Add parallel axis injection with configurable probability
- Add grounded parameter on all 7 generators (grounded/floating)
- Add axis sampling strategies: cardinal, random, near-parallel
- Update _create_joint with orientation-aware anchor offsets
- Add _resolve_axis helper for parallel axis propagation
- Update generate_training_batch with axis_strategy, parallel_axis_prob,
  grounded_ratio parameters
- Add body_orientations and grounded fields to batch output
- Export AxisStrategy from datagen package
- Add 28 new tests (72 total generator tests, 158 total)

Closes #8
2026-02-02 14:57:49 -06:00
0b5813b5a9 feat: parameterized assembly templates and complexity tiers
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Add 4 new topology generators to SyntheticAssemblyGenerator:
- generate_tree_assembly: random spanning tree with configurable branching
- generate_loop_assembly: closed ring producing overconstrained data
- generate_star_assembly: hub-and-spoke topology
- generate_mixed_assembly: tree + loops with configurable edge density

Each accepts joint_types as JointType | list[JointType] for per-joint
type sampling.

Add complexity tiers (simple/medium/complex) with predefined body count
ranges via COMPLEXITY_RANGES dict and ComplexityTier type alias.

Update generate_training_batch with 7-way generator selection,
complexity_tier parameter, and generator_type field in output dicts.

Extract private helpers (_random_position, _random_axis,
_select_joint_type, _create_joint) to reduce duplication.

44 generator tests, 130 total — all passing.

Closes #7
2026-02-02 14:38:05 -06:00
dc742bfc82 test: add unit tests for datagen modules
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
- test_types.py: JointType enum values/count, dataclass defaults/isolation
- test_pebble_game.py: DOF accounting, rigidity, classification, edge results
- test_jacobian.py: Jacobian shape per joint type, rank, parallel axis degeneracy
- test_analysis.py: demo scenarios (revolute, fixed, triangle, parallel axes)
- test_generator.py: chain/rigid/overconstrained generation, training batch

Bug fixes found during testing:
- JointType enum: duplicate int values caused aliasing (SLIDER=REVOLUTE etc).
  Changed to (ordinal, dof) tuple values with a .dof property.
- pebble_game.py: .value -> .dof for constraint count
- analysis.py: classify from effective DOF (not raw pebble game with virtual
  ground body skew)

105 tests, all passing.

Closes #6
2026-02-02 14:08:22 -06:00
831a10cdb4 feat: port SyntheticAssemblyGenerator to solver/datagen/generator.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Port chain, rigid, and overconstrained assembly generators plus
the training batch generation from data/synthetic/pebble-game.py.

- Refactored rng.choice on enums/callables to integer indexing (mypy)
- Typed n_bodies_range as tuple[int, int]
- Typed batch return as list[dict[str, Any]]
- Full type annotations (mypy strict)
- Re-exported from solver.datagen.__init__

Closes #5
2026-02-02 13:54:32 -06:00
9a31df4988 feat: port analyze_assembly to solver/datagen/analysis.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Port the combined pebble game + Jacobian verification entry point from
data/synthetic/pebble-game.py. Ties PebbleGame3D and JacobianVerifier
together with virtual ground body support.

- Optional[int] -> int | None (UP007)
- GROUND_ID constant extracted to module level
- Full type annotations (mypy strict)
- Re-exported from solver.datagen.__init__

Closes #4
2026-02-02 13:52:03 -06:00
455b6318d9 feat: port JacobianVerifier to solver/datagen/jacobian.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Port the constraint Jacobian builder and numerical rank verifier from
data/synthetic/pebble-game.py. All 11 joint type builders, SVD rank
computation, and incremental dependency detection.

- Full type annotations (mypy strict)
- Ruff lint and format clean
- Re-exported from solver.datagen.__init__

Closes #3
2026-02-02 13:50:16 -06:00
35d4ef736f feat: port PebbleGame3D to solver/datagen/pebble_game.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Port the (6,6)-pebble game implementation from data/synthetic/pebble-game.py.
Imports shared types from solver.datagen.types. No behavioral changes.

- Full type annotations on all methods (mypy strict)
- Ruff-compliant: ternary, combined if, unpacking
- Re-exported from solver.datagen.__init__

Closes #2
2026-02-02 13:47:36 -06:00
1b6135129e feat: port shared types to solver/datagen/types.py
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
Port JointType, RigidBody, Joint, PebbleState, and ConstraintAnalysis
from data/synthetic/pebble-game.py into the solver package.

- Add __all__ export list
- Put typing.Any behind TYPE_CHECKING (ruff TCH003)
- Parameterize list[dict] as list[dict[str, Any]] (mypy strict)
- Re-export all types from solver.datagen.__init__

Closes #1
2026-02-02 13:43:19 -06:00
363b49281b build: phase 0 infrastructure setup
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
- Project structure: solver/, freecad/, export/, configs/, scripts/, tests/, docs/
- pyproject.toml with dependency groups: core, train, freecad, dev
- Hydra configs: dataset (synthetic, fusion360), model (baseline, gat), training (pretrain, finetune), export (production)
- Dockerfile with CUDA+PyG GPU and CPU-only targets
- docker-compose.yml for train, test, data-gen services
- Makefile with targets: train, test, lint, format, type-check, data-gen, export, check
- Pre-commit hooks: ruff, mypy, conventional commits
- Gitea Actions CI: lint, type-check, test on push/PR
- README with setup and usage instructions
2026-02-02 13:26:38 -06:00
f61d005400 first commit 2026-02-02 13:09:37 -06:00
forbes
e32c9cd793 fix: use previous iteration dxNorm in convergence check
isConvergedToNumericalLimit() compared dxNorms->at(iterNo) to itself
instead of comparing current vs previous iteration. This prevented
the solver from detecting convergence improvement, causing it to
exhaust its iteration limit on assemblies with many constraints.

Fix: read dxNorms->at(iterNo - 1) for the previous iteration's norm.
2026-02-01 21:10:12 -06:00
62 changed files with 8866 additions and 25 deletions

180
.gitea/workflows/ci.yaml Normal file
View File

@@ -0,0 +1,180 @@
name: CI
on:
push:
branches: [main, public]
pull_request:
branches: [main, public]
workflow_dispatch:
inputs:
run_datagen:
description: "Run dataset generation"
required: false
type: boolean
default: false
num_assemblies:
description: "Number of assemblies to generate"
required: false
type: string
default: "100000"
num_workers:
description: "Parallel workers for datagen"
required: false
type: string
default: "4"
env:
PIP_CACHE_DIR: /tmp/pip-cache-solver
TORCH_INDEX: https://download.pytorch.org/whl/cpu
VIRTUAL_ENV: /tmp/solver-venv
jobs:
# ---------------------------------------------------------------------------
# Lint — fast, no torch required
# ---------------------------------------------------------------------------
lint:
runs-on: ubuntu-latest
env:
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
steps:
- name: Checkout
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout "$GITHUB_SHA" 2>/dev/null || true
- name: Set up venv
run: python3 -m venv $VIRTUAL_ENV
- name: Install lint tools
run: pip install --cache-dir $PIP_CACHE_DIR ruff
- name: Ruff check
run: ruff check solver/ freecad/ tests/ scripts/
- name: Ruff format check
run: ruff format --check solver/ freecad/ tests/ scripts/
# ---------------------------------------------------------------------------
# Type check
# ---------------------------------------------------------------------------
type-check:
runs-on: ubuntu-latest
env:
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
steps:
- name: Checkout
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout "$GITHUB_SHA" 2>/dev/null || true
- name: Set up venv
run: python3 -m venv $VIRTUAL_ENV
- name: Install dependencies
run: |
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
pip install --cache-dir $PIP_CACHE_DIR mypy numpy scipy
pip install --cache-dir $PIP_CACHE_DIR -e ".[dev]"
- name: Mypy
run: mypy solver/ freecad/
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
test:
runs-on: ubuntu-latest
env:
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
steps:
- name: Checkout
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout "$GITHUB_SHA" 2>/dev/null || true
- name: Set up venv
run: python3 -m venv $VIRTUAL_ENV
- name: Install dependencies
run: |
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
pip install --cache-dir $PIP_CACHE_DIR -e ".[train,dev]"
- name: Run tests
run: pytest tests/ freecad/tests/ -v --tb=short
# ---------------------------------------------------------------------------
# Dataset generation — manual trigger or on main/public push
# ---------------------------------------------------------------------------
datagen:
runs-on: ubuntu-latest
if: >-
(github.event_name == 'workflow_dispatch' && inputs.run_datagen == true) ||
(github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/public'))
needs: [test]
env:
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
steps:
- name: Checkout
run: |
git config --global --add safe.directory "$GITHUB_WORKSPACE"
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
cd "$GITHUB_WORKSPACE"
git checkout "$GITHUB_SHA" 2>/dev/null || true
- name: Set up venv
run: python3 -m venv $VIRTUAL_ENV
- name: Install dependencies
run: |
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
pip install --cache-dir $PIP_CACHE_DIR -e ".[train]"
- name: Generate dataset
run: |
NUM=${INPUTS_NUM_ASSEMBLIES:-100000}
WORKERS=${INPUTS_NUM_WORKERS:-4}
echo "Generating ${NUM} assemblies with ${WORKERS} workers"
python3 scripts/generate_synthetic.py \
--num-assemblies "${NUM}" \
--num-workers "${WORKERS}" \
--output-dir data/synthetic
env:
INPUTS_NUM_ASSEMBLIES: ${{ inputs.num_assemblies }}
INPUTS_NUM_WORKERS: ${{ inputs.num_workers }}
- name: Print summary
if: always()
run: |
echo "=== Dataset Generation Results ==="
if [ -f data/synthetic/stats.json ]; then
python3 -c "
import json
with open('data/synthetic/stats.json') as f:
s = json.load(f)
print(f'Total examples: {s[\"total_examples\"]}')
print(f'Classification: {json.dumps(s[\"classification_distribution\"], indent=2)}')
print(f'Rigid: {s[\"rigidity\"][\"rigid_fraction\"]*100:.1f}%')
print(f'Degeneracy: {s[\"geometric_degeneracy\"][\"fraction_with_degeneracy\"]*100:.1f}%')
"
else
echo "stats.json not found — generation may have failed"
ls -la data/synthetic/ 2>/dev/null || echo "output dir missing"
fi

77
.gitignore vendored
View File

@@ -1,44 +1,83 @@
# Prerequisites
# C++ compiled objects
*.d
# Compiled Object files
*.slo
*.lo
*.o
*.obj
# Precompiled Headers
*.gch
*.pch
# Compiled Dynamic libraries
# C++ libraries
*.so
*.dylib
*.dll
# Fortran module files
*.mod
*.smod
# Compiled Static libraries
*.lai
*.la
*.a
*.lib
# Executables
# C++ executables
*.exe
*.out
*.app
.vs
# C++ build
build/
cmake-build-debug/
.vs/
x64/
temp/
# OndselSolver test artifacts
*.bak
assembly.asmt
build
cmake-build-debug
.idea
temp/
/testapp/draggingBackhoe.log
/testapp/runPreDragBackhoe.asmt
# Python
__pycache__/
*.py[cod]
*$py.class
*.egg-info/
dist/
*.egg
# Virtual environments
.venv/
venv/
# IDE
.vscode/
.idea/
*.swp
*.swo
*~
# mypy / ruff / pytest
.mypy_cache/
.ruff_cache/
.pytest_cache/
# Data (large files tracked separately)
data/synthetic/*.pt
data/fusion360/*.json
data/fusion360/*.step
data/processed/*.pt
!data/**/.gitkeep
# Model checkpoints
*.ckpt
*.pth
*.onnx
*.torchscript
# Experiment tracking
wandb/
runs/
# OS
.DS_Store
Thumbs.db
# Environment
.env

23
.pre-commit-config.yaml Normal file
View File

@@ -0,0 +1,23 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.3.4
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
- id: mypy
additional_dependencies:
- torch>=2.2
- numpy>=1.26
args: [--ignore-missing-imports]
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v3.1.0
hooks:
- id: conventional-pre-commit
stages: [commit-msg]
args: [feat, fix, docs, style, refactor, perf, test, build, ci, chore, revert]

61
Dockerfile Normal file
View File

@@ -0,0 +1,61 @@
FROM nvidia/cuda:12.4.1-devel-ubuntu22.04 AS base
ENV DEBIAN_FRONTEND=noninteractive
ENV PYTHONUNBUFFERED=1
# System deps
RUN apt-get update && apt-get install -y --no-install-recommends \
python3.11 python3.11-venv python3.11-dev python3-pip \
git wget curl \
# FreeCAD headless deps
freecad \
libgl1-mesa-glx libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
RUN update-alternatives --install /usr/bin/python python /usr/bin/python3.11 1
# Create venv
RUN python -m venv /opt/venv
ENV PATH="/opt/venv/bin:$PATH"
# Install PyTorch with CUDA
RUN pip install --no-cache-dir \
torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu124
# Install PyG
RUN pip install --no-cache-dir \
torch-geometric \
pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv \
-f https://data.pyg.org/whl/torch-2.4.0+cu124.html
WORKDIR /workspace
# Install project
COPY pyproject.toml .
RUN pip install --no-cache-dir -e ".[train,dev]" || true
COPY . .
RUN pip install --no-cache-dir -e ".[train,dev]"
# -------------------------------------------------------------------
FROM base AS cpu
# CPU-only variant (for CI and non-GPU environments)
FROM python:3.11-slim AS cpu-only
ENV PYTHONUNBUFFERED=1
RUN apt-get update && apt-get install -y --no-install-recommends \
git freecad libgl1-mesa-glx libglib2.0-0 \
&& rm -rf /var/lib/apt/lists/*
WORKDIR /workspace
COPY pyproject.toml .
RUN pip install --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
RUN pip install --no-cache-dir torch-geometric
COPY . .
RUN pip install --no-cache-dir -e ".[train,dev]"
CMD ["pytest", "tests/", "-v"]

48
Makefile Normal file
View File

@@ -0,0 +1,48 @@
.PHONY: train test lint data-gen export format type-check install dev clean help
PYTHON ?= python
PYTEST ?= pytest
RUFF ?= ruff
MYPY ?= mypy
help: ## Show this help
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | \
awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-20s\033[0m %s\n", $$1, $$2}'
install: ## Install core dependencies
pip install -e .
dev: ## Install all dependencies including dev tools
pip install -e ".[train,dev]"
pre-commit install
pre-commit install --hook-type commit-msg
train: ## Run training (pass CONFIG=path/to/config.yaml)
$(PYTHON) -m solver.training.train $(if $(CONFIG),--config-path $(CONFIG))
test: ## Run test suite
$(PYTEST) tests/ freecad/tests/ -v --tb=short
lint: ## Run ruff linter
$(RUFF) check solver/ freecad/ tests/ scripts/
format: ## Format code with ruff
$(RUFF) format solver/ freecad/ tests/ scripts/
$(RUFF) check --fix solver/ freecad/ tests/ scripts/
type-check: ## Run mypy type checker
$(MYPY) solver/ freecad/
data-gen: ## Generate synthetic dataset (pass CONFIG=path/to/config.yaml)
$(PYTHON) scripts/generate_synthetic.py $(if $(CONFIG),--config-path $(CONFIG))
export: ## Export trained model for deployment
$(PYTHON) export/package_model.py $(if $(MODEL),--model $(MODEL))
clean: ## Remove build artifacts and caches
rm -rf build/ dist/ *.egg-info/
rm -rf .mypy_cache/ .pytest_cache/ .ruff_cache/
find . -type d -name __pycache__ -exec rm -rf {} + 2>/dev/null || true
find . -type f -name "*.pyc" -delete 2>/dev/null || true
check: lint type-check test ## Run all checks (lint, type-check, test)

View File

@@ -112,7 +112,7 @@ bool NewtonRaphson::isConvergedToNumericalLimit()
size_t nDivergenceMax = 3;
auto dxNormIterNo = dxNorms->at(iterNo);
if (iterNo > 0) {
auto dxNormIterNoOld = dxNorms->at(iterNo);
auto dxNormIterNoOld = dxNorms->at(iterNo - 1);
auto farTooLargeError = dxNormIterNo > tooLargeTol;
auto worthIterating = dxNormIterNo > (smallEnoughTol * pow(10.0, (iterNo / iterMax) * nDecade));
bool stillConverging;

View File

@@ -1,7 +1,91 @@
# MbDCode
Assembly Constraints and Multibody Dynamics code
# Kindred Solver
Install freecad9a.exe from ar-cad.com. Run program and read Explain menu items for documentations. (edited)
Assembly constraint solver for [Kindred Create](https://git.kindred-systems.com/kindred/create). Combines a numerical multibody dynamics engine (OndselSolver) with a GNN-based constraint prediction layer.
The MbD theory is at
https://github.com/Ondsel-Development/MbDTheory
## Components
### OndselSolver (C++)
Numerical assembly constraint solver using multibody dynamics. Solves joint constraints between rigid bodies using a Newton-Raphson iterative approach. Used by FreeCAD's Assembly workbench as the backend solver.
- Source: `OndselSolver/`
- Entry point: `OndselSolverMain/`
- Tests: `tests/`, `testapp/`
- Build: CMake
**Theory:** [MbDTheory](https://github.com/Ondsel-Development/MbDTheory)
#### Building
```bash
cmake -B build -DCMAKE_BUILD_TYPE=Release
cmake --build build
```
### ML Solver Layer (Python)
Graph neural network that predicts constraint independence and per-body degrees of freedom. Trained on synthetic assembly data generated via the pebble game algorithm, with the goal of augmenting or replacing the numerical solver for common assembly patterns.
- Core library: `solver/`
- Data generation: `solver/datagen/` (pebble game, synthetic assemblies, labeling)
- Model architectures: `solver/models/` (GIN, GAT, NNConv)
- Training: `solver/training/`
- Inference: `solver/inference/`
- FreeCAD integration: `freecad/`
- Configuration: `configs/` (Hydra)
#### Setup
```bash
pip install -e ".[train,dev]"
pre-commit install
```
#### Usage
```bash
make help # show all targets
make dev # install all deps + pre-commit hooks
make test # run tests
make lint # run ruff linter
make check # lint + type-check + test
make data-gen # generate synthetic data
make train # run training
make export # export model
```
Docker is also supported:
```bash
docker compose up train # GPU training
docker compose up test # run tests
docker compose up data-gen # generate synthetic data
```
## Repository structure
```
kindred-solver/
├── OndselSolver/ # C++ numerical solver library
├── OndselSolverMain/ # C++ solver CLI entry point
├── tests/ # C++ unit tests + Python tests
├── testapp/ # C++ test application
├── solver/ # Python ML solver library
│ ├── datagen/ # Synthetic data generation (pebble game)
│ ├── datasets/ # PyG dataset adapters
│ ├── models/ # GNN architectures
│ ├── training/ # Training loops
│ ├── evaluation/ # Metrics and visualization
│ └── inference/ # Runtime prediction API
├── freecad/ # FreeCAD workbench integration
├── configs/ # Hydra configs (dataset, model, training, export)
├── scripts/ # CLI utilities
├── data/ # Datasets (not committed)
├── export/ # Model packaging
└── docs/ # Documentation
```
## License
OndselSolver: LGPL-2.1-or-later (see [LICENSE](LICENSE))
ML Solver Layer: Apache-2.0

View File

@@ -0,0 +1,12 @@
# Fusion 360 Gallery dataset config
name: fusion360
data_dir: data/fusion360
output_dir: data/processed
splits:
train: 0.8
val: 0.1
test: 0.1
stratify_by: complexity
seed: 42

View File

@@ -0,0 +1,26 @@
# Synthetic dataset generation config
name: synthetic
num_assemblies: 100000
output_dir: data/synthetic
shard_size: 1000
complexity_distribution:
simple: 0.4 # 2-5 bodies
medium: 0.4 # 6-15 bodies
complex: 0.2 # 16-50 bodies
body_count:
min: 2
max: 50
templates:
- chain
- tree
- loop
- star
- mixed
grounded_ratio: 0.5
seed: 42
num_workers: 4
checkpoint_every: 5

View File

@@ -0,0 +1,25 @@
# Production model export config
model_checkpoint: checkpoints/finetune/best_val_loss.ckpt
output_dir: export/
formats:
onnx:
enabled: true
opset_version: 17
dynamic_axes: true
torchscript:
enabled: true
model_card:
version: "0.1.0"
architecture: baseline
training_data:
- synthetic_100k
- fusion360_gallery
size_budget_mb: 50
inference:
device: cpu
batch_size: 1
confidence_threshold: 0.8

View File

@@ -0,0 +1,24 @@
# Baseline GIN model config
name: baseline
architecture: gin
encoder:
num_layers: 3
hidden_dim: 128
dropout: 0.1
node_features_dim: 22
edge_features_dim: 22
heads:
edge_classification:
enabled: true
hidden_dim: 64
graph_classification:
enabled: true
num_classes: 4 # rigid, under, over, mixed
joint_type:
enabled: true
num_classes: 12
dof_regression:
enabled: true

28
configs/model/gat.yaml Normal file
View File

@@ -0,0 +1,28 @@
# Advanced GAT model config
name: gat_solver
architecture: gat
encoder:
num_layers: 4
hidden_dim: 256
num_heads: 8
dropout: 0.1
residual: true
node_features_dim: 22
edge_features_dim: 22
heads:
edge_classification:
enabled: true
hidden_dim: 128
graph_classification:
enabled: true
num_classes: 4
joint_type:
enabled: true
num_classes: 12
dof_regression:
enabled: true
dof_tracking:
enabled: true

View File

@@ -0,0 +1,45 @@
# Fine-tuning on real data config
phase: finetune
dataset: fusion360
model: baseline
pretrained_checkpoint: checkpoints/pretrain/best_val_loss.ckpt
optimizer:
name: adamw
lr: 1e-5
weight_decay: 1e-4
scheduler:
name: cosine_annealing
T_max: 50
eta_min: 1e-7
training:
epochs: 50
batch_size: 32
gradient_clip: 1.0
early_stopping_patience: 10
amp: true
freeze_encoder: false # set true for frozen encoder experiment
loss:
edge_weight: 1.0
graph_weight: 0.5
joint_type_weight: 0.3
dof_weight: 0.2
redundant_penalty: 2.0
checkpointing:
save_best_val_loss: true
save_best_val_accuracy: true
save_every_n_epochs: 5
checkpoint_dir: checkpoints/finetune
logging:
backend: wandb
project: kindred-solver
log_every_n_steps: 20
seed: 42

View File

@@ -0,0 +1,42 @@
# Synthetic pre-training config
phase: pretrain
dataset: synthetic
model: baseline
optimizer:
name: adamw
lr: 1e-3
weight_decay: 1e-4
scheduler:
name: cosine_annealing
T_max: 100
eta_min: 1e-6
training:
epochs: 100
batch_size: 64
gradient_clip: 1.0
early_stopping_patience: 10
amp: true
loss:
edge_weight: 1.0
graph_weight: 0.5
joint_type_weight: 0.3
dof_weight: 0.2
redundant_penalty: 2.0 # safety loss multiplier
checkpointing:
save_best_val_loss: true
save_best_val_accuracy: true
save_every_n_epochs: 10
checkpoint_dir: checkpoints/pretrain
logging:
backend: wandb # or tensorboard
project: kindred-solver
log_every_n_steps: 50
seed: 42

0
data/fusion360/.gitkeep Normal file
View File

0
data/processed/.gitkeep Normal file
View File

0
data/splits/.gitkeep Normal file
View File

0
data/synthetic/.gitkeep Normal file
View File

39
docker-compose.yml Normal file
View File

@@ -0,0 +1,39 @@
services:
train:
build:
context: .
dockerfile: Dockerfile
target: base
volumes:
- .:/workspace
- ./data:/workspace/data
deploy:
resources:
reservations:
devices:
- driver: nvidia
count: all
capabilities: [gpu]
command: make train
environment:
- CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-0}
- WANDB_API_KEY=${WANDB_API_KEY:-}
test:
build:
context: .
dockerfile: Dockerfile
target: cpu-only
volumes:
- .:/workspace
command: make check
data-gen:
build:
context: .
dockerfile: Dockerfile
target: base
volumes:
- .:/workspace
- ./data:/workspace/data
command: make data-gen

0
docs/.gitkeep Normal file
View File

0
export/.gitkeep Normal file
View File

0
freecad/__init__.py Normal file
View File

View File

View File

View File

97
pyproject.toml Normal file
View File

@@ -0,0 +1,97 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[project]
name = "kindred-solver"
version = "0.1.0"
description = "Assembly constraint prediction via GNN for Kindred Create"
readme = "README.md"
license = "Apache-2.0"
requires-python = ">=3.11"
authors = [
{ name = "Kindred Systems" },
]
classifiers = [
"Development Status :: 3 - Alpha",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Topic :: Scientific/Engineering",
]
dependencies = [
"torch>=2.2",
"torch-geometric>=2.5",
"numpy>=1.26",
"scipy>=1.12",
]
[project.optional-dependencies]
train = [
"wandb>=0.16",
"tensorboard>=2.16",
"hydra-core>=1.3",
"omegaconf>=2.3",
"matplotlib>=3.8",
"networkx>=3.2",
]
freecad = [
"pyside6>=6.6",
]
dev = [
"pytest>=8.0",
"pytest-cov>=4.1",
"ruff>=0.3",
"mypy>=1.8",
"pre-commit>=3.6",
]
[project.urls]
Repository = "https://git.kindred-systems.com/kindred/solver"
[tool.hatch.build.targets.wheel]
packages = ["solver", "freecad"]
[tool.ruff]
target-version = "py311"
line-length = 100
[tool.ruff.lint]
select = [
"E", # pycodestyle errors
"W", # pycodestyle warnings
"F", # pyflakes
"I", # isort
"N", # pep8-naming
"UP", # pyupgrade
"B", # flake8-bugbear
"SIM", # flake8-simplify
"TCH", # flake8-type-checking
"RUF", # ruff-specific
]
[tool.ruff.lint.isort]
known-first-party = ["solver", "freecad"]
[tool.mypy]
python_version = "3.11"
warn_return_any = true
warn_unused_configs = true
disallow_untyped_defs = true
check_untyped_defs = true
[[tool.mypy.overrides]]
module = [
"torch.*",
"torch_geometric.*",
"scipy.*",
"wandb.*",
"hydra.*",
"omegaconf.*",
]
ignore_missing_imports = true
[tool.pytest.ini_options]
testpaths = ["tests", "freecad/tests"]
addopts = "-v --tb=short"

View File

@@ -0,0 +1,115 @@
#!/usr/bin/env python3
"""Generate synthetic assembly dataset for kindred-solver training.
Usage (argparse fallback — always available)::
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
Usage (Hydra — when hydra-core is installed)::
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
"""
from __future__ import annotations
def _try_hydra_main() -> bool:
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
try:
import hydra # type: ignore[import-untyped]
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
except ImportError:
return False
@hydra.main(
config_path="../configs/dataset",
config_name="synthetic",
version_base=None,
)
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
config_dict = OmegaConf.to_container(cfg, resolve=True)
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
DatasetGenerator(config).run()
_run() # type: ignore[no-untyped-call]
return True
def _argparse_main() -> None:
"""Fallback CLI using argparse."""
import argparse
parser = argparse.ArgumentParser(
description="Generate synthetic assembly dataset",
)
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to YAML config file (optional)",
)
parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
parser.add_argument("--seed", type=int, default=None, help="Random seed")
parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
parser.add_argument(
"--checkpoint-every",
type=int,
default=None,
help="Checkpoint interval (shards)",
)
parser.add_argument(
"--no-resume",
action="store_true",
help="Do not resume from existing checkpoints",
)
args = parser.parse_args()
from solver.datagen.dataset import (
DatasetConfig,
DatasetGenerator,
parse_simple_yaml,
)
config_dict: dict[str, object] = {}
if args.config:
config_dict = parse_simple_yaml(args.config) # type: ignore[assignment]
# CLI args override config file (only when explicitly provided)
_override_map = {
"num_assemblies": args.num_assemblies,
"output_dir": args.output_dir,
"shard_size": args.shard_size,
"body_count_min": args.body_count_min,
"body_count_max": args.body_count_max,
"grounded_ratio": args.grounded_ratio,
"seed": args.seed,
"num_workers": args.num_workers,
"checkpoint_every": args.checkpoint_every,
}
for key, val in _override_map.items():
if val is not None:
config_dict[key] = val
if args.no_resume:
config_dict["resume"] = False
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
DatasetGenerator(config).run()
def main() -> None:
"""Entry point: try Hydra first, fall back to argparse."""
if not _try_hydra_main():
_argparse_main()
if __name__ == "__main__":
main()

0
solver/__init__.py Normal file
View File

View File

@@ -0,0 +1,37 @@
"""Data generation utilities for assembly constraint training data."""
from solver.datagen.analysis import analyze_assembly
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
from solver.datagen.generator import (
COMPLEXITY_RANGES,
AxisStrategy,
SyntheticAssemblyGenerator,
)
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.labeling import AssemblyLabels, label_assembly
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
PebbleState,
RigidBody,
)
__all__ = [
"COMPLEXITY_RANGES",
"AssemblyLabels",
"AxisStrategy",
"ConstraintAnalysis",
"DatasetConfig",
"DatasetGenerator",
"JacobianVerifier",
"Joint",
"JointType",
"PebbleGame3D",
"PebbleState",
"RigidBody",
"SyntheticAssemblyGenerator",
"analyze_assembly",
"label_assembly",
]

140
solver/datagen/analysis.py Normal file
View File

@@ -0,0 +1,140 @@
"""Combined pebble game + Jacobian verification analysis.
Provides :func:`analyze_assembly`, the main entry point for full rigidity
analysis of an assembly using both combinatorial and numerical methods.
"""
from __future__ import annotations
import numpy as np
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
__all__ = ["analyze_assembly"]
_GROUND_ID = -1
def analyze_assembly(
bodies: list[RigidBody],
joints: list[Joint],
ground_body: int | None = None,
) -> ConstraintAnalysis:
"""Full rigidity analysis of an assembly using both methods.
Args:
bodies: List of rigid bodies in the assembly.
joints: List of joints connecting bodies.
ground_body: If set, this body is fixed (adds 6 implicit constraints).
Returns:
ConstraintAnalysis with combinatorial and numerical results.
"""
# --- Pebble Game ---
pg = PebbleGame3D()
all_edge_results = []
# Add a virtual ground body (id=-1) if grounding is requested.
# Grounding body X means adding a fixed joint between X and
# the virtual ground. This properly lets the pebble game account
# for the 6 removed DOF without breaking invariants.
if ground_body is not None:
pg.add_body(_GROUND_ID)
for body in bodies:
pg.add_body(body.body_id)
if ground_body is not None:
ground_joint = Joint(
joint_id=-1,
body_a=ground_body,
body_b=_GROUND_ID,
joint_type=JointType.FIXED,
anchor_a=bodies[0].position if bodies else np.zeros(3),
anchor_b=bodies[0].position if bodies else np.zeros(3),
)
pg.add_joint(ground_joint)
# Don't include ground joint edges in the output labels
# (they're infrastructure, not user constraints)
for joint in joints:
results = pg.add_joint(joint)
all_edge_results.extend(results)
combinatorial_independent = len(pg.state.independent_edges)
grounded = ground_body is not None
# The virtual ground body contributes 6 pebbles to the total.
# Subtract those from the reported DOF for user-facing numbers.
raw_dof = pg.get_dof()
ground_offset = 6 if grounded else 0
effective_dof = raw_dof - ground_offset
effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))
# Classify based on effective (adjusted) DOF, not raw pebble game output,
# because the virtual ground body skews the raw numbers.
redundant = pg.get_redundant_count()
if redundant > 0 and effective_internal_dof > 0:
combinatorial_classification = "mixed"
elif redundant > 0:
combinatorial_classification = "overconstrained"
elif effective_internal_dof > 0:
combinatorial_classification = "underconstrained"
else:
combinatorial_classification = "well-constrained"
# --- Jacobian Verification ---
verifier = JacobianVerifier(bodies)
for joint in joints:
verifier.add_joint_constraints(joint)
# If grounded, remove the ground body's columns (fix its DOF)
j = verifier.get_jacobian()
if ground_body is not None and j.size > 0:
idx = verifier.body_index[ground_body]
cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
j = np.delete(j, cols_to_remove, axis=1)
if j.size > 0:
sv = np.linalg.svd(j, compute_uv=False)
jacobian_rank = int(np.sum(sv > 1e-8))
else:
jacobian_rank = 0
n_cols = j.shape[1] if j.size > 0 else 6 * len(bodies)
jacobian_nullity = n_cols - jacobian_rank
dependent = verifier.find_dependencies()
# Adjust for ground
trivial_dof = 0 if ground_body is not None else 6
jacobian_internal_dof = jacobian_nullity - trivial_dof
geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
# Rigidity: numerically rigid if nullity == trivial DOF
is_rigid = jacobian_nullity <= trivial_dof
is_minimally_rigid = is_rigid and len(dependent) == 0
return ConstraintAnalysis(
combinatorial_dof=effective_dof,
combinatorial_internal_dof=effective_internal_dof,
combinatorial_redundant=pg.get_redundant_count(),
combinatorial_classification=combinatorial_classification,
per_edge_results=all_edge_results,
jacobian_rank=jacobian_rank,
jacobian_nullity=jacobian_nullity,
jacobian_internal_dof=max(0, jacobian_internal_dof),
numerically_dependent=dependent,
geometric_degeneracies=geometric_degeneracies,
is_rigid=is_rigid,
is_minimally_rigid=is_minimally_rigid,
)

624
solver/datagen/dataset.py Normal file
View File

@@ -0,0 +1,624 @@
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
for parallel generation of synthetic assembly training data.
"""
from __future__ import annotations
import hashlib
import json
import logging
import math
import sys
import time
from concurrent.futures import ProcessPoolExecutor, as_completed
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"DatasetConfig",
"DatasetGenerator",
"parse_simple_yaml",
]
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
@dataclass
class DatasetConfig:
"""Configuration for synthetic dataset generation."""
name: str = "synthetic"
num_assemblies: int = 100_000
output_dir: str = "data/synthetic"
shard_size: int = 1000
complexity_distribution: dict[str, float] = field(
default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
)
body_count_min: int = 2
body_count_max: int = 50
templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"])
grounded_ratio: float = 0.5
seed: int = 42
num_workers: int = 4
checkpoint_every: int = 5
resume: bool = True
@classmethod
def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
"""Construct from a parsed config dict (e.g. YAML or OmegaConf).
Handles both flat keys (``body_count_min``) and nested forms
(``body_count: {min: 2, max: 50}``).
"""
kw: dict[str, Any] = {}
for key in (
"name",
"num_assemblies",
"output_dir",
"shard_size",
"grounded_ratio",
"seed",
"num_workers",
"checkpoint_every",
"resume",
):
if key in d:
kw[key] = d[key]
# Handle nested body_count dict
if "body_count" in d and isinstance(d["body_count"], dict):
bc = d["body_count"]
if "min" in bc:
kw["body_count_min"] = int(bc["min"])
if "max" in bc:
kw["body_count_max"] = int(bc["max"])
else:
if "body_count_min" in d:
kw["body_count_min"] = int(d["body_count_min"])
if "body_count_max" in d:
kw["body_count_max"] = int(d["body_count_max"])
if "complexity_distribution" in d:
cd = d["complexity_distribution"]
if isinstance(cd, dict):
kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()}
if "templates" in d and isinstance(d["templates"], list):
kw["templates"] = [str(t) for t in d["templates"]]
return cls(**kw)
# ---------------------------------------------------------------------------
# Shard specification / result
# ---------------------------------------------------------------------------
@dataclass
class ShardSpec:
"""Specification for generating a single shard."""
shard_id: int
start_example_id: int
count: int
seed: int
complexity_distribution: dict[str, float]
body_count_min: int
body_count_max: int
grounded_ratio: float
@dataclass
class ShardResult:
"""Result returned from a shard worker."""
shard_id: int
num_examples: int
file_path: str
generation_time_s: float
# ---------------------------------------------------------------------------
# Seed derivation
# ---------------------------------------------------------------------------
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
"""Derive a deterministic per-shard seed from the global seed."""
h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest()
return int(h[:8], 16)
# ---------------------------------------------------------------------------
# Progress display
# ---------------------------------------------------------------------------
class _PrintProgress:
"""Fallback progress display when tqdm is unavailable."""
def __init__(self, total: int) -> None:
self.total = total
self.current = 0
self.start_time = time.monotonic()
def update(self, n: int = 1) -> None:
self.current += n
elapsed = time.monotonic() - self.start_time
rate = self.current / elapsed if elapsed > 0 else 0.0
eta = (self.total - self.current) / rate if rate > 0 else 0.0
pct = 100.0 * self.current / self.total
sys.stdout.write(
f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
)
sys.stdout.flush()
def close(self) -> None:
sys.stdout.write("\n")
sys.stdout.flush()
def _make_progress(total: int) -> _PrintProgress:
"""Create a progress tracker (tqdm if available, else print-based)."""
try:
from tqdm import tqdm # type: ignore[import-untyped]
return tqdm(total=total, desc="Generating shards", unit="shard") # type: ignore[no-any-return,return-value]
except ImportError:
return _PrintProgress(total)
# ---------------------------------------------------------------------------
# Shard I/O
# ---------------------------------------------------------------------------
def _save_shard(
shard_id: int,
examples: list[dict[str, Any]],
shards_dir: Path,
) -> Path:
"""Save a shard to disk (.pt if torch available, else .json)."""
shards_dir.mkdir(parents=True, exist_ok=True)
try:
import torch # type: ignore[import-untyped]
path = shards_dir / f"shard_{shard_id:05d}.pt"
torch.save(examples, path)
except ImportError:
path = shards_dir / f"shard_{shard_id:05d}.json"
with open(path, "w") as f:
json.dump(examples, f)
return path
def _load_shard(path: Path) -> list[dict[str, Any]]:
"""Load a shard from disk (.pt or .json)."""
if path.suffix == ".pt":
import torch # type: ignore[import-untyped]
result: list[dict[str, Any]] = torch.load(path, weights_only=False)
return result
with open(path) as f:
result = json.load(f)
return result
def _shard_format() -> str:
"""Return the shard file extension based on available libraries."""
try:
import torch # type: ignore[import-untyped] # noqa: F401
return ".pt"
except ImportError:
return ".json"
# ---------------------------------------------------------------------------
# Shard worker (module-level for pickling)
# ---------------------------------------------------------------------------
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
"""Generate a single shard — top-level function for ProcessPoolExecutor."""
from solver.datagen.generator import SyntheticAssemblyGenerator
t0 = time.monotonic()
gen = SyntheticAssemblyGenerator(seed=spec.seed)
rng = np.random.default_rng(spec.seed + 1)
tiers = list(spec.complexity_distribution.keys())
probs_list = [spec.complexity_distribution[t] for t in tiers]
total = sum(probs_list)
probs = [p / total for p in probs_list]
examples: list[dict[str, Any]] = []
for i in range(spec.count):
tier_idx = int(rng.choice(len(tiers), p=probs))
tier = tiers[tier_idx]
try:
batch = gen.generate_training_batch(
batch_size=1,
complexity_tier=tier, # type: ignore[arg-type]
grounded_ratio=spec.grounded_ratio,
)
ex = batch[0]
ex["example_id"] = spec.start_example_id + i
ex["complexity_tier"] = tier
examples.append(ex)
except Exception:
logger.warning(
"Shard %d, example %d failed — skipping",
spec.shard_id,
i,
exc_info=True,
)
shards_dir = Path(output_dir) / "shards"
path = _save_shard(spec.shard_id, examples, shards_dir)
elapsed = time.monotonic() - t0
return ShardResult(
shard_id=spec.shard_id,
num_examples=len(examples),
file_path=str(path),
generation_time_s=elapsed,
)
# ---------------------------------------------------------------------------
# Minimal YAML parser
# ---------------------------------------------------------------------------
def _parse_scalar(value: str) -> int | float | bool | str:
"""Parse a YAML scalar value."""
# Strip inline comments (space + #)
if " #" in value:
value = value[: value.index(" #")].strip()
elif " #" in value:
value = value[: value.index(" #")].strip()
v = value.strip()
if v.lower() in ("true", "yes"):
return True
if v.lower() in ("false", "no"):
return False
try:
return int(v)
except ValueError:
pass
try:
return float(v)
except ValueError:
pass
return v.strip("'\"")
def parse_simple_yaml(path: str) -> dict[str, Any]:
"""Parse a simple YAML file (flat scalars, one-level dicts, lists).
This is **not** a full YAML parser. It handles the structure of
``configs/dataset/synthetic.yaml``.
"""
result: dict[str, Any] = {}
current_key: str | None = None
with open(path) as f:
for raw_line in f:
line = raw_line.rstrip()
# Skip blank lines and full-line comments
if not line or line.lstrip().startswith("#"):
continue
indent = len(line) - len(line.lstrip())
if indent == 0 and ":" in line:
key, _, value = line.partition(":")
key = key.strip()
value = value.strip()
if value:
result[key] = _parse_scalar(value)
current_key = None
else:
current_key = key
result[key] = {}
continue
if indent > 0 and line.lstrip().startswith("- "):
item = line.lstrip()[2:].strip()
if current_key is not None:
if isinstance(result.get(current_key), dict) and not result[current_key]:
result[current_key] = []
if isinstance(result.get(current_key), list):
result[current_key].append(_parse_scalar(item))
continue
if indent > 0 and ":" in line and current_key is not None:
k, _, v = line.partition(":")
k = k.strip()
v = v.strip()
if v:
# Strip inline comments
if " #" in v:
v = v[: v.index(" #")].strip()
if not isinstance(result.get(current_key), dict):
result[current_key] = {}
result[current_key][k] = _parse_scalar(v)
continue
return result
# ---------------------------------------------------------------------------
# Dataset generator orchestrator
# ---------------------------------------------------------------------------
class DatasetGenerator:
"""Orchestrates parallel dataset generation with sharding and checkpointing."""
def __init__(self, config: DatasetConfig) -> None:
self.config = config
self.output_path = Path(config.output_dir)
self.shards_dir = self.output_path / "shards"
self.checkpoint_file = self.output_path / ".checkpoint.json"
self.index_file = self.output_path / "index.json"
self.stats_file = self.output_path / "stats.json"
# -- public API --
def run(self) -> None:
"""Generate the full dataset."""
self.output_path.mkdir(parents=True, exist_ok=True)
self.shards_dir.mkdir(parents=True, exist_ok=True)
shards = self._plan_shards()
total_shards = len(shards)
# Resume: find already-completed shards
completed: set[int] = set()
if self.config.resume:
completed = self._find_completed_shards()
pending = [s for s in shards if s.shard_id not in completed]
if not pending:
logger.info("All %d shards already complete.", total_shards)
else:
logger.info(
"Generating %d shards (%d already complete).",
len(pending),
len(completed),
)
progress = _make_progress(len(pending))
workers = max(1, self.config.num_workers)
checkpoint_counter = 0
with ProcessPoolExecutor(max_workers=workers) as pool:
futures = {
pool.submit(
_generate_shard_worker,
spec,
str(self.output_path),
): spec.shard_id
for spec in pending
}
for future in as_completed(futures):
shard_id = futures[future]
try:
result = future.result()
completed.add(result.shard_id)
logger.debug(
"Shard %d: %d examples in %.1fs",
result.shard_id,
result.num_examples,
result.generation_time_s,
)
except Exception:
logger.error("Shard %d failed", shard_id, exc_info=True)
progress.update(1)
checkpoint_counter += 1
if checkpoint_counter >= self.config.checkpoint_every:
self._update_checkpoint(completed, total_shards)
checkpoint_counter = 0
progress.close()
# Finalize
self._build_index()
stats = self._compute_statistics()
self._write_statistics(stats)
self._print_summary(stats)
# Remove checkpoint (generation complete)
if self.checkpoint_file.exists():
self.checkpoint_file.unlink()
# -- internal helpers --
def _plan_shards(self) -> list[ShardSpec]:
"""Divide num_assemblies into shards."""
n = self.config.num_assemblies
size = self.config.shard_size
num_shards = math.ceil(n / size)
shards: list[ShardSpec] = []
for i in range(num_shards):
start = i * size
count = min(size, n - start)
shards.append(
ShardSpec(
shard_id=i,
start_example_id=start,
count=count,
seed=_derive_shard_seed(self.config.seed, i),
complexity_distribution=dict(self.config.complexity_distribution),
body_count_min=self.config.body_count_min,
body_count_max=self.config.body_count_max,
grounded_ratio=self.config.grounded_ratio,
)
)
return shards
def _find_completed_shards(self) -> set[int]:
"""Scan shards directory for existing shard files."""
completed: set[int] = set()
if not self.shards_dir.exists():
return completed
for p in self.shards_dir.iterdir():
if p.stem.startswith("shard_"):
try:
shard_id = int(p.stem.split("_")[1])
# Verify file is non-empty
if p.stat().st_size > 0:
completed.add(shard_id)
except (ValueError, IndexError):
pass
return completed
def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
"""Write checkpoint file."""
data = {
"completed_shards": sorted(completed),
"total_shards": total_shards,
"timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
}
with open(self.checkpoint_file, "w") as f:
json.dump(data, f)
def _build_index(self) -> None:
"""Build index.json mapping shard files to assembly ID ranges."""
shards_info: dict[str, dict[str, int]] = {}
total_assemblies = 0
for p in sorted(self.shards_dir.iterdir()):
if not p.stem.startswith("shard_"):
continue
try:
shard_id = int(p.stem.split("_")[1])
except (ValueError, IndexError):
continue
examples = _load_shard(p)
count = len(examples)
start_id = shard_id * self.config.shard_size
shards_info[p.name] = {"start_id": start_id, "count": count}
total_assemblies += count
fmt = _shard_format().lstrip(".")
index = {
"format_version": 1,
"total_assemblies": total_assemblies,
"total_shards": len(shards_info),
"shard_format": fmt,
"shards": shards_info,
}
with open(self.index_file, "w") as f:
json.dump(index, f, indent=2)
def _compute_statistics(self) -> dict[str, Any]:
"""Aggregate statistics across all shards."""
classification_counts: dict[str, int] = {}
body_count_hist: dict[int, int] = {}
joint_type_counts: dict[str, int] = {}
dof_values: list[int] = []
degeneracy_values: list[int] = []
rigid_count = 0
minimally_rigid_count = 0
total = 0
for p in sorted(self.shards_dir.iterdir()):
if not p.stem.startswith("shard_"):
continue
examples = _load_shard(p)
for ex in examples:
total += 1
cls = str(ex.get("assembly_classification", "unknown"))
classification_counts[cls] = classification_counts.get(cls, 0) + 1
nb = int(ex.get("n_bodies", 0))
body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
for j in ex.get("joints", []):
jt = str(j.get("type", "unknown"))
joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
dof_values.append(int(ex.get("internal_dof", 0)))
degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
if ex.get("is_rigid"):
rigid_count += 1
if ex.get("is_minimally_rigid"):
minimally_rigid_count += 1
dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)
return {
"total_examples": total,
"classification_distribution": dict(sorted(classification_counts.items())),
"body_count_histogram": dict(sorted(body_count_hist.items())),
"joint_type_distribution": dict(sorted(joint_type_counts.items())),
"dof_statistics": {
"mean": float(dof_arr.mean()),
"std": float(dof_arr.std()),
"min": int(dof_arr.min()),
"max": int(dof_arr.max()),
"median": float(np.median(dof_arr)),
},
"geometric_degeneracy": {
"assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
"fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
"mean_degeneracies": float(deg_arr.mean()),
},
"rigidity": {
"rigid_count": rigid_count,
"rigid_fraction": (rigid_count / total if total > 0 else 0.0),
"minimally_rigid_count": minimally_rigid_count,
"minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
},
}
def _write_statistics(self, stats: dict[str, Any]) -> None:
"""Write stats.json."""
with open(self.stats_file, "w") as f:
json.dump(stats, f, indent=2)
def _print_summary(self, stats: dict[str, Any]) -> None:
"""Print a human-readable summary to stdout."""
print("\n=== Dataset Generation Summary ===")
print(f"Total examples: {stats['total_examples']}")
print(f"Output directory: {self.output_path}")
print()
print("Classification distribution:")
for cls, count in stats["classification_distribution"].items():
frac = count / max(stats["total_examples"], 1) * 100
print(f" {cls}: {count} ({frac:.1f}%)")
print()
print("Joint type distribution:")
for jt, count in stats["joint_type_distribution"].items():
print(f" {jt}: {count}")
print()
dof = stats["dof_statistics"]
print(
f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
)
rig = stats["rigidity"]
print(
f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
f"{rig['minimally_rigid_count']} minimally rigid"
)
deg = stats["geometric_degeneracy"]
print(
f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
)

893
solver/datagen/generator.py Normal file
View File

@@ -0,0 +1,893 @@
"""Synthetic assembly graph generator for training data production.
Generates assembly graphs with known constraint classifications using
the pebble game and Jacobian verification. Each assembly is fully labeled
with per-constraint independence flags and assembly-level classification.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Literal
import numpy as np
from scipy.spatial.transform import Rotation
from solver.datagen.analysis import analyze_assembly
from solver.datagen.labeling import label_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
if TYPE_CHECKING:
from typing import Any
__all__ = [
"COMPLEXITY_RANGES",
"AxisStrategy",
"ComplexityTier",
"SyntheticAssemblyGenerator",
]
# ---------------------------------------------------------------------------
# Complexity tiers — ranges use exclusive upper bound for rng.integers()
# ---------------------------------------------------------------------------
ComplexityTier = Literal["simple", "medium", "complex"]
COMPLEXITY_RANGES: dict[str, tuple[int, int]] = {
"simple": (2, 6),
"medium": (6, 16),
"complex": (16, 51),
}
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
AxisStrategy = Literal["cardinal", "random", "near_parallel"]
class SyntheticAssemblyGenerator:
"""Generates assembly graphs with known minimal constraint sets.
Uses the pebble game to incrementally build assemblies, tracking
exactly which constraints are independent at each step. This produces
labeled training data: (assembly_graph, constraint_set, labels).
Labels per constraint:
- independent: bool (does this constraint remove a DOF?)
- redundant: bool (is this constraint overconstrained?)
- minimal_set: bool (part of a minimal rigidity basis?)
"""
def __init__(self, seed: int = 42) -> None:
self.rng = np.random.default_rng(seed)
# ------------------------------------------------------------------
# Private helpers
# ------------------------------------------------------------------
def _random_position(self, scale: float = 5.0) -> np.ndarray:
"""Generate random 3D position within [-scale, scale] cube."""
return self.rng.uniform(-scale, scale, size=3)
def _random_axis(self) -> np.ndarray:
"""Generate random normalized 3D axis."""
axis = self.rng.standard_normal(3)
axis /= np.linalg.norm(axis)
return axis
def _random_orientation(self) -> np.ndarray:
"""Generate a random 3x3 rotation matrix."""
mat: np.ndarray = Rotation.random(random_state=self.rng).as_matrix()
return mat
def _cardinal_axis(self) -> np.ndarray:
"""Pick uniformly from the six signed cardinal directions."""
axes = np.array(
[
[1, 0, 0],
[-1, 0, 0],
[0, 1, 0],
[0, -1, 0],
[0, 0, 1],
[0, 0, -1],
],
dtype=float,
)
result: np.ndarray = axes[int(self.rng.integers(6))]
return result
def _near_parallel_axis(
self,
base_axis: np.ndarray,
perturbation_scale: float = 0.05,
) -> np.ndarray:
"""Return *base_axis* with a small random perturbation, re-normalized."""
perturbed = base_axis + self.rng.standard_normal(3) * perturbation_scale
return perturbed / np.linalg.norm(perturbed)
def _sample_axis(self, strategy: AxisStrategy = "random") -> np.ndarray:
"""Sample a joint axis using the specified strategy."""
if strategy == "cardinal":
return self._cardinal_axis()
if strategy == "near_parallel":
return self._near_parallel_axis(np.array([0.0, 0.0, 1.0]))
return self._random_axis()
def _resolve_axis(
self,
strategy: AxisStrategy,
parallel_axis_prob: float,
shared_axis: np.ndarray | None,
) -> tuple[np.ndarray, np.ndarray | None]:
"""Return (axis_for_this_joint, shared_axis_to_propagate).
On the first call where *shared_axis* is ``None`` and parallel
injection triggers, a base axis is chosen and returned as
*shared_axis* for subsequent calls.
"""
if shared_axis is not None:
return self._near_parallel_axis(shared_axis), shared_axis
if parallel_axis_prob > 0 and self.rng.random() < parallel_axis_prob:
base = self._sample_axis(strategy)
return base.copy(), base
return self._sample_axis(strategy), None
def _select_joint_type(
self,
joint_types: JointType | list[JointType],
) -> JointType:
"""Select a joint type from a single type or list."""
if isinstance(joint_types, list):
idx = int(self.rng.integers(0, len(joint_types)))
return joint_types[idx]
return joint_types
def _create_joint(
self,
joint_id: int,
body_a_id: int,
body_b_id: int,
pos_a: np.ndarray,
pos_b: np.ndarray,
joint_type: JointType,
*,
axis: np.ndarray | None = None,
orient_a: np.ndarray | None = None,
orient_b: np.ndarray | None = None,
) -> Joint:
"""Create a joint between two bodies.
When orientations are provided, anchor points are offset from
each body's center along a random local direction rotated into
world frame, rather than placed at the midpoint.
"""
if orient_a is not None and orient_b is not None:
dist = max(float(np.linalg.norm(pos_b - pos_a)), 0.1)
offset_scale = dist * 0.2
local_a = self.rng.standard_normal(3) * offset_scale
local_b = self.rng.standard_normal(3) * offset_scale
anchor_a = pos_a + orient_a @ local_a
anchor_b = pos_b + orient_b @ local_b
else:
anchor = (pos_a + pos_b) / 2.0
anchor_a = anchor
anchor_b = anchor
return Joint(
joint_id=joint_id,
body_a=body_a_id,
body_b=body_b_id,
joint_type=joint_type,
anchor_a=anchor_a,
anchor_b=anchor_b,
axis=axis if axis is not None else self._random_axis(),
)
# ------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ------------------------------------------------------------------
def generate_chain_assembly(
self,
n_bodies: int,
joint_type: JointType = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a serial kinematic chain.
Simple but useful: each body connects to the next with the
specified joint type. Results in an underconstrained assembly
(serial chain is never rigid without closing loops).
"""
bodies = []
joints = []
for i in range(n_bodies):
pos = np.array([i * 2.0, 0.0, 0.0])
bodies.append(
RigidBody(
body_id=i,
position=pos,
orientation=self._random_orientation(),
)
)
shared_axis: np.ndarray | None = None
for i in range(n_bodies - 1):
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i,
i,
i + 1,
bodies[i].position,
bodies[i + 1].position,
joint_type,
axis=axis,
orient_a=bodies[i].orientation,
orient_b=bodies[i + 1].orientation,
)
)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_rigid_assembly(
self,
n_bodies: int,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a minimally rigid assembly by adding joints until rigid.
Strategy: start with fixed joints on a spanning tree (guarantees
rigidity), then randomly relax some to weaker joint types while
maintaining rigidity via the pebble game check.
"""
bodies = []
for i in range(n_bodies):
bodies.append(
RigidBody(
body_id=i,
position=self._random_position(),
orientation=self._random_orientation(),
)
)
# Build spanning tree with fixed joints (overconstrained)
joints: list[Joint] = []
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
parent = int(self.rng.integers(0, i))
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i - 1,
parent,
i,
bodies[parent].position,
bodies[i].position,
JointType.FIXED,
axis=axis,
orient_a=bodies[parent].orientation,
orient_b=bodies[i].orientation,
)
)
# Try relaxing joints to weaker types while maintaining rigidity
weaker_types = [
JointType.REVOLUTE,
JointType.CYLINDRICAL,
JointType.BALL,
]
ground = 0 if grounded else None
for idx in self.rng.permutation(len(joints)):
original_type = joints[idx].joint_type
for candidate in weaker_types:
joints[idx].joint_type = candidate
analysis = analyze_assembly(bodies, joints, ground_body=ground)
if analysis.is_rigid:
break # Keep the weaker type
else:
joints[idx].joint_type = original_type
analysis = analyze_assembly(bodies, joints, ground_body=ground)
return bodies, joints, analysis
def generate_overconstrained_assembly(
self,
n_bodies: int,
extra_joints: int = 2,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate an assembly with known redundant constraints.
Starts with a rigid assembly, then adds extra joints that
the pebble game will flag as redundant.
"""
bodies, joints, _ = self.generate_rigid_assembly(
n_bodies,
grounded=grounded,
axis_strategy=axis_strategy,
parallel_axis_prob=parallel_axis_prob,
)
joint_id = len(joints)
shared_axis: np.ndarray | None = None
for _ in range(extra_joints):
a, b = self.rng.choice(n_bodies, size=2, replace=False)
_overcon_types = [
JointType.REVOLUTE,
JointType.FIXED,
JointType.BALL,
]
jtype = _overcon_types[int(self.rng.integers(len(_overcon_types)))]
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
int(a),
int(b),
bodies[int(a)].position,
bodies[int(b)].position,
jtype,
axis=axis,
orient_a=bodies[int(a)].orientation,
orient_b=bodies[int(b)].orientation,
)
)
joint_id += 1
ground = 0 if grounded else None
analysis = analyze_assembly(bodies, joints, ground_body=ground)
return bodies, joints, analysis
# ------------------------------------------------------------------
# New topology generators
# ------------------------------------------------------------------
def generate_tree_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
branching_factor: int = 3,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a random tree topology with configurable branching.
Creates a tree where each body can have up to *branching_factor*
children. Different branches can use different joint types if a
list is provided. Always underconstrained (no closed loops).
Args:
n_bodies: Total bodies (root + children).
joint_types: Single type or list to sample from per joint.
branching_factor: Max children per parent (1-5 recommended).
"""
bodies: list[RigidBody] = [
RigidBody(
body_id=0,
position=np.zeros(3),
orientation=self._random_orientation(),
)
]
joints: list[Joint] = []
available_parents = [0]
next_id = 1
joint_id = 0
shared_axis: np.ndarray | None = None
while next_id < n_bodies and available_parents:
pidx = int(self.rng.integers(0, len(available_parents)))
parent_id = available_parents[pidx]
parent_pos = bodies[parent_id].position
max_children = min(branching_factor, n_bodies - next_id)
n_children = int(self.rng.integers(1, max_children + 1))
for _ in range(n_children):
direction = self._random_axis()
distance = self.rng.uniform(1.5, 3.0)
child_pos = parent_pos + direction * distance
child_orient = self._random_orientation()
bodies.append(
RigidBody(
body_id=next_id,
position=child_pos,
orientation=child_orient,
)
)
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
parent_id,
next_id,
parent_pos,
child_pos,
jtype,
axis=axis,
orient_a=bodies[parent_id].orientation,
orient_b=child_orient,
)
)
available_parents.append(next_id)
next_id += 1
joint_id += 1
if next_id >= n_bodies:
break
# Retire parent if it reached branching limit or randomly
if n_children >= branching_factor or self.rng.random() < 0.3:
available_parents.pop(pidx)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_loop_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a single closed loop (ring) of bodies.
The closing constraint introduces redundancy, making this
useful for generating overconstrained training data.
Args:
n_bodies: Bodies in the ring (>= 3).
joint_types: Single type or list to sample from per joint.
Raises:
ValueError: If *n_bodies* < 3.
"""
if n_bodies < 3:
msg = "Loop assembly requires at least 3 bodies"
raise ValueError(msg)
bodies: list[RigidBody] = []
joints: list[Joint] = []
base_radius = max(2.0, n_bodies * 0.4)
for i in range(n_bodies):
angle = 2 * np.pi * i / n_bodies
radius = base_radius + self.rng.uniform(-0.5, 0.5)
x = radius * np.cos(angle)
y = radius * np.sin(angle)
z = float(self.rng.uniform(-0.2, 0.2))
bodies.append(
RigidBody(
body_id=i,
position=np.array([x, y, z]),
orientation=self._random_orientation(),
)
)
shared_axis: np.ndarray | None = None
for i in range(n_bodies):
next_i = (i + 1) % n_bodies
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i,
i,
next_i,
bodies[i].position,
bodies[next_i].position,
jtype,
axis=axis,
orient_a=bodies[i].orientation,
orient_b=bodies[next_i].orientation,
)
)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_star_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a star topology with central hub and satellites.
Body 0 is the hub; all other bodies connect directly to it.
Underconstrained because there are no inter-satellite connections.
Args:
n_bodies: Total bodies including hub (>= 2).
joint_types: Single type or list to sample from per joint.
Raises:
ValueError: If *n_bodies* < 2.
"""
if n_bodies < 2:
msg = "Star assembly requires at least 2 bodies"
raise ValueError(msg)
hub_orient = self._random_orientation()
bodies: list[RigidBody] = [
RigidBody(
body_id=0,
position=np.zeros(3),
orientation=hub_orient,
)
]
joints: list[Joint] = []
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
direction = self._random_axis()
distance = self.rng.uniform(2.0, 5.0)
pos = direction * distance
sat_orient = self._random_orientation()
bodies.append(RigidBody(body_id=i, position=pos, orientation=sat_orient))
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
i - 1,
0,
i,
np.zeros(3),
pos,
jtype,
axis=axis,
orient_a=hub_orient,
orient_b=sat_orient,
)
)
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
def generate_mixed_assembly(
self,
n_bodies: int,
joint_types: JointType | list[JointType] = JointType.REVOLUTE,
edge_density: float = 0.3,
*,
grounded: bool = True,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
) -> tuple[list[RigidBody], list[Joint], ConstraintAnalysis]:
"""Generate a mixed topology combining tree and loop elements.
Builds a spanning tree for connectivity, then adds extra edges
based on *edge_density* to create loops and redundancy.
Args:
n_bodies: Number of bodies.
joint_types: Single type or list to sample from per joint.
edge_density: Fraction of non-tree edges to add (0.0-1.0).
Raises:
ValueError: If *edge_density* not in [0.0, 1.0].
"""
if not 0.0 <= edge_density <= 1.0:
msg = "edge_density must be in [0.0, 1.0]"
raise ValueError(msg)
bodies: list[RigidBody] = []
joints: list[Joint] = []
for i in range(n_bodies):
bodies.append(
RigidBody(
body_id=i,
position=self._random_position(),
orientation=self._random_orientation(),
)
)
# Phase 1: spanning tree
joint_id = 0
existing_edges: set[frozenset[int]] = set()
shared_axis: np.ndarray | None = None
for i in range(1, n_bodies):
parent = int(self.rng.integers(0, i))
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
parent,
i,
bodies[parent].position,
bodies[i].position,
jtype,
axis=axis,
orient_a=bodies[parent].orientation,
orient_b=bodies[i].orientation,
)
)
existing_edges.add(frozenset([parent, i]))
joint_id += 1
# Phase 2: extra edges based on density
candidates: list[tuple[int, int]] = []
for i in range(n_bodies):
for j in range(i + 1, n_bodies):
if frozenset([i, j]) not in existing_edges:
candidates.append((i, j))
n_extra = int(edge_density * len(candidates))
self.rng.shuffle(candidates)
for a, b in candidates[:n_extra]:
jtype = self._select_joint_type(joint_types)
axis, shared_axis = self._resolve_axis(
axis_strategy,
parallel_axis_prob,
shared_axis,
)
joints.append(
self._create_joint(
joint_id,
a,
b,
bodies[a].position,
bodies[b].position,
jtype,
axis=axis,
orient_a=bodies[a].orientation,
orient_b=bodies[b].orientation,
)
)
joint_id += 1
analysis = analyze_assembly(
bodies,
joints,
ground_body=0 if grounded else None,
)
return bodies, joints, analysis
# ------------------------------------------------------------------
# Batch generation
# ------------------------------------------------------------------
def generate_training_batch(
self,
batch_size: int = 100,
n_bodies_range: tuple[int, int] | None = None,
complexity_tier: ComplexityTier | None = None,
*,
axis_strategy: AxisStrategy = "random",
parallel_axis_prob: float = 0.0,
grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
"""Generate a batch of labeled training examples.
Each example contains body positions, joint descriptions,
per-joint independence labels, and assembly-level classification.
Args:
batch_size: Number of assemblies to generate.
n_bodies_range: ``(min, max_exclusive)`` body count.
Overridden by *complexity_tier* when both are given.
complexity_tier: Predefined range (``"simple"`` / ``"medium"``
/ ``"complex"``). Overrides *n_bodies_range*.
axis_strategy: Axis sampling strategy for joint axes.
parallel_axis_prob: Probability of parallel axis injection.
grounded_ratio: Fraction of examples that are grounded.
"""
if complexity_tier is not None:
n_bodies_range = COMPLEXITY_RANGES[complexity_tier]
elif n_bodies_range is None:
n_bodies_range = (3, 8)
_joint_pool = [
JointType.REVOLUTE,
JointType.BALL,
JointType.CYLINDRICAL,
JointType.FIXED,
]
geo_kw: dict[str, Any] = {
"axis_strategy": axis_strategy,
"parallel_axis_prob": parallel_axis_prob,
}
examples: list[dict[str, Any]] = []
for i in range(batch_size):
n = int(self.rng.integers(*n_bodies_range))
gen_idx = int(self.rng.integers(7))
grounded = bool(self.rng.random() < grounded_ratio)
if gen_idx == 0:
_chain_types = [
JointType.REVOLUTE,
JointType.BALL,
JointType.CYLINDRICAL,
]
jtype = _chain_types[int(self.rng.integers(len(_chain_types)))]
bodies, joints, analysis = self.generate_chain_assembly(
n,
jtype,
grounded=grounded,
**geo_kw,
)
gen_name = "chain"
elif gen_idx == 1:
bodies, joints, analysis = self.generate_rigid_assembly(
n,
grounded=grounded,
**geo_kw,
)
gen_name = "rigid"
elif gen_idx == 2:
extra = int(self.rng.integers(1, 4))
bodies, joints, analysis = self.generate_overconstrained_assembly(
n,
extra,
grounded=grounded,
**geo_kw,
)
gen_name = "overconstrained"
elif gen_idx == 3:
branching = int(self.rng.integers(2, 5))
bodies, joints, analysis = self.generate_tree_assembly(
n,
_joint_pool,
branching,
grounded=grounded,
**geo_kw,
)
gen_name = "tree"
elif gen_idx == 4:
n = max(n, 3)
bodies, joints, analysis = self.generate_loop_assembly(
n,
_joint_pool,
grounded=grounded,
**geo_kw,
)
gen_name = "loop"
elif gen_idx == 5:
n = max(n, 2)
bodies, joints, analysis = self.generate_star_assembly(
n,
_joint_pool,
grounded=grounded,
**geo_kw,
)
gen_name = "star"
else:
density = float(self.rng.uniform(0.2, 0.5))
bodies, joints, analysis = self.generate_mixed_assembly(
n,
_joint_pool,
density,
grounded=grounded,
**geo_kw,
)
gen_name = "mixed"
# Produce ground truth labels (includes ConstraintAnalysis)
ground = 0 if grounded else None
labels = label_assembly(bodies, joints, ground_body=ground)
analysis = labels.analysis
# Build per-joint labels from edge results
joint_labels: dict[int, dict[str, int]] = {}
for result in analysis.per_edge_results:
jid = result["joint_id"]
if jid not in joint_labels:
joint_labels[jid] = {
"independent_constraints": 0,
"redundant_constraints": 0,
"total_constraints": 0,
}
joint_labels[jid]["total_constraints"] += 1
if result["independent"]:
joint_labels[jid]["independent_constraints"] += 1
else:
joint_labels[jid]["redundant_constraints"] += 1
examples.append(
{
"example_id": i,
"generator_type": gen_name,
"grounded": grounded,
"n_bodies": len(bodies),
"n_joints": len(joints),
"body_positions": [b.position.tolist() for b in bodies],
"body_orientations": [b.orientation.tolist() for b in bodies],
"joints": [
{
"joint_id": j.joint_id,
"body_a": j.body_a,
"body_b": j.body_b,
"type": j.joint_type.name,
"axis": j.axis.tolist(),
}
for j in joints
],
"joint_labels": joint_labels,
"labels": labels.to_dict(),
"assembly_classification": (analysis.combinatorial_classification),
"is_rigid": analysis.is_rigid,
"is_minimally_rigid": analysis.is_minimally_rigid,
"internal_dof": analysis.jacobian_internal_dof,
"geometric_degeneracies": (analysis.geometric_degeneracies),
}
)
return examples

517
solver/datagen/jacobian.py Normal file
View File

@@ -0,0 +1,517 @@
"""Numerical Jacobian rank verification for assembly constraint analysis.
Builds the constraint Jacobian matrix and analyzes its numerical rank
to detect geometric degeneracies that the combinatorial pebble game
cannot identify (e.g., parallel revolute axes creating hidden dependencies).
References:
- Chappuis, "Constraints Derivation for Rigid Body Simulation in 3D"
"""
from __future__ import annotations
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.types import Joint, JointType, RigidBody
if TYPE_CHECKING:
from typing import Any
__all__ = ["JacobianVerifier"]
class JacobianVerifier:
"""Builds and analyzes the constraint Jacobian for numerical rank check.
The pebble game gives a combinatorial *necessary* condition for
rigidity. However, geometric special cases (e.g., all revolute axes
parallel, creating a hidden dependency) require numerical verification.
For each joint, we construct the constraint Jacobian rows that map
the 6n-dimensional generalized velocity vector to the constraint
violation rates. The rank of this Jacobian equals the number of
truly independent constraints.
The generalized velocity vector for n bodies is::
v = [v1_x, v1_y, v1_z, w1_x, w1_y, w1_z, ..., vn_x, ..., wn_z]
Each scalar constraint C_i contributes one row to J such that::
dC_i/dt = J_i @ v = 0
"""
def __init__(self, bodies: list[RigidBody]) -> None:
self.bodies = {b.body_id: b for b in bodies}
self.body_index = {b.body_id: i for i, b in enumerate(bodies)}
self.n_bodies = len(bodies)
self.jacobian_rows: list[np.ndarray] = []
self.row_labels: list[dict[str, Any]] = []
def _body_cols(self, body_id: int) -> tuple[int, int]:
"""Return the column range [start, end) for a body in J."""
idx = self.body_index[body_id]
return idx * 6, (idx + 1) * 6
def add_joint_constraints(self, joint: Joint) -> int:
"""Add Jacobian rows for all scalar constraints of a joint.
Returns the number of rows added.
"""
builder = {
JointType.FIXED: self._build_fixed,
JointType.REVOLUTE: self._build_revolute,
JointType.CYLINDRICAL: self._build_cylindrical,
JointType.SLIDER: self._build_slider,
JointType.BALL: self._build_ball,
JointType.PLANAR: self._build_planar,
JointType.DISTANCE: self._build_distance,
JointType.PARALLEL: self._build_parallel,
JointType.PERPENDICULAR: self._build_perpendicular,
JointType.UNIVERSAL: self._build_universal,
JointType.SCREW: self._build_screw,
}
rows_before = len(self.jacobian_rows)
builder[joint.joint_type](joint)
return len(self.jacobian_rows) - rows_before
def _make_row(self) -> np.ndarray:
"""Create a zero row of width 6*n_bodies."""
return np.zeros(6 * self.n_bodies)
def _skew(self, v: np.ndarray) -> np.ndarray:
"""Skew-symmetric matrix for cross product: ``skew(v) @ w = v x w``."""
return np.array(
[
[0, -v[2], v[1]],
[v[2], 0, -v[0]],
[-v[1], v[0], 0],
]
)
# --- Ball-and-socket (spherical) joint: 3 translation constraints ---
def _build_ball(self, joint: Joint) -> None:
"""Ball joint: coincident point constraint.
``C_trans = (x_b + R_b @ r_b) - (x_a + R_a @ r_a) = 0``
(3 equations)
Jacobian rows (for each of x, y, z):
body_a linear: -I
body_a angular: +skew(R_a @ r_a)
body_b linear: +I
body_b angular: -skew(R_b @ r_b)
"""
# Use anchor positions directly as world-frame offsets
r_a = joint.anchor_a - self.bodies[joint.body_a].position
r_b = joint.anchor_b - self.bodies[joint.body_b].position
col_a_start, col_a_end = self._body_cols(joint.body_a)
col_b_start, col_b_end = self._body_cols(joint.body_b)
for axis_idx in range(3):
row = self._make_row()
e = np.zeros(3)
e[axis_idx] = 1.0
row[col_a_start : col_a_start + 3] = -e
row[col_a_start + 3 : col_a_end] = np.cross(r_a, e)
row[col_b_start : col_b_start + 3] = e
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, e)
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "ball_translation",
"axis": axis_idx,
}
)
# --- Fixed joint: 3 translation + 3 rotation constraints ---
def _build_fixed(self, joint: Joint) -> None:
"""Fixed joint = ball joint + locked rotation.
Translation part: same as ball joint (3 rows).
Rotation part: relative angular velocity must be zero (3 rows).
"""
self._build_ball(joint)
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
for axis_idx in range(3):
row = self._make_row()
row[col_a_start + 3 + axis_idx] = -1.0
row[col_b_start + 3 + axis_idx] = 1.0
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "fixed_rotation",
"axis": axis_idx,
}
)
# --- Revolute (hinge) joint: 3 translation + 2 rotation constraints ---
def _build_revolute(self, joint: Joint) -> None:
"""Revolute joint: rotation only about one axis.
Translation: same as ball (3 rows).
Rotation: relative angular velocity must be parallel to hinge axis.
"""
self._build_ball(joint)
axis = joint.axis / np.linalg.norm(joint.axis)
t1, t2 = self._perpendicular_pair(axis)
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
for i, t in enumerate((t1, t2)):
row = self._make_row()
row[col_a_start + 3 : col_a_start + 6] = -t
row[col_b_start + 3 : col_b_start + 6] = t
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "revolute_rotation",
"perp_axis": i,
}
)
# --- Cylindrical joint: 2 translation + 2 rotation constraints ---
def _build_cylindrical(self, joint: Joint) -> None:
"""Cylindrical joint: allows rotation + translation along one axis.
Translation: constrain motion perpendicular to axis (2 rows).
Rotation: constrain rotation perpendicular to axis (2 rows).
"""
axis = joint.axis / np.linalg.norm(joint.axis)
t1, t2 = self._perpendicular_pair(axis)
r_a = joint.anchor_a - self.bodies[joint.body_a].position
r_b = joint.anchor_b - self.bodies[joint.body_b].position
col_a_start, col_a_end = self._body_cols(joint.body_a)
col_b_start, col_b_end = self._body_cols(joint.body_b)
for i, t in enumerate((t1, t2)):
row = self._make_row()
row[col_a_start : col_a_start + 3] = -t
row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
row[col_b_start : col_b_start + 3] = t
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "cylindrical_translation",
"perp_axis": i,
}
)
for i, t in enumerate((t1, t2)):
row = self._make_row()
row[col_a_start + 3 : col_a_start + 6] = -t
row[col_b_start + 3 : col_b_start + 6] = t
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "cylindrical_rotation",
"perp_axis": i,
}
)
# --- Slider (prismatic) joint: 2 translation + 3 rotation constraints ---
def _build_slider(self, joint: Joint) -> None:
"""Slider/prismatic joint: translation along one axis only.
Translation: perpendicular translation constrained (2 rows).
Rotation: all relative rotation constrained (3 rows).
"""
axis = joint.axis / np.linalg.norm(joint.axis)
t1, t2 = self._perpendicular_pair(axis)
r_a = joint.anchor_a - self.bodies[joint.body_a].position
r_b = joint.anchor_b - self.bodies[joint.body_b].position
col_a_start, col_a_end = self._body_cols(joint.body_a)
col_b_start, col_b_end = self._body_cols(joint.body_b)
for i, t in enumerate((t1, t2)):
row = self._make_row()
row[col_a_start : col_a_start + 3] = -t
row[col_a_start + 3 : col_a_end] = np.cross(r_a, t)
row[col_b_start : col_b_start + 3] = t
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, t)
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "slider_translation",
"perp_axis": i,
}
)
for axis_idx in range(3):
row = self._make_row()
row[col_a_start + 3 + axis_idx] = -1.0
row[col_b_start + 3 + axis_idx] = 1.0
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "slider_rotation",
"axis": axis_idx,
}
)
# --- Planar joint: 1 translation + 2 rotation constraints ---
def _build_planar(self, joint: Joint) -> None:
"""Planar joint: constrains to a plane.
Translation: motion along plane normal constrained (1 row).
Rotation: rotation about axes in the plane constrained (2 rows).
"""
normal = joint.axis / np.linalg.norm(joint.axis)
t1, t2 = self._perpendicular_pair(normal)
r_a = joint.anchor_a - self.bodies[joint.body_a].position
r_b = joint.anchor_b - self.bodies[joint.body_b].position
col_a_start, col_a_end = self._body_cols(joint.body_a)
col_b_start, col_b_end = self._body_cols(joint.body_b)
row = self._make_row()
row[col_a_start : col_a_start + 3] = -normal
row[col_a_start + 3 : col_a_end] = np.cross(r_a, normal)
row[col_b_start : col_b_start + 3] = normal
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, normal)
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "planar_translation",
}
)
for i, t in enumerate((t1, t2)):
row = self._make_row()
row[col_a_start + 3 : col_a_start + 6] = -t
row[col_b_start + 3 : col_b_start + 6] = t
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "planar_rotation",
"perp_axis": i,
}
)
# --- Distance constraint: 1 scalar ---
def _build_distance(self, joint: Joint) -> None:
"""Distance constraint: ``||p_b - p_a|| = d``.
Single row: ``direction . (v_b + w_b x r_b - v_a - w_a x r_a) = 0``
where ``direction = normalized(p_b - p_a)``.
"""
p_a = joint.anchor_a
p_b = joint.anchor_b
diff = p_b - p_a
dist = np.linalg.norm(diff)
direction = np.array([1.0, 0.0, 0.0]) if dist < 1e-12 else diff / dist
r_a = joint.anchor_a - self.bodies[joint.body_a].position
r_b = joint.anchor_b - self.bodies[joint.body_b].position
col_a_start, col_a_end = self._body_cols(joint.body_a)
col_b_start, col_b_end = self._body_cols(joint.body_b)
row = self._make_row()
row[col_a_start : col_a_start + 3] = -direction
row[col_a_start + 3 : col_a_end] = np.cross(r_a, direction)
row[col_b_start : col_b_start + 3] = direction
row[col_b_start + 3 : col_b_end] = -np.cross(r_b, direction)
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "distance",
}
)
# --- Parallel constraint: 3 rotation constraints ---
def _build_parallel(self, joint: Joint) -> None:
"""Parallel: all relative rotation constrained (same as fixed rotation).
In practice only 2 of 3 are independent for a single axis, but
we emit 3 and let the rank check sort it out.
"""
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
for axis_idx in range(3):
row = self._make_row()
row[col_a_start + 3 + axis_idx] = -1.0
row[col_b_start + 3 + axis_idx] = 1.0
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "parallel_rotation",
"axis": axis_idx,
}
)
# --- Perpendicular constraint: 1 angular ---
def _build_perpendicular(self, joint: Joint) -> None:
"""Perpendicular: single dot-product angular constraint."""
axis = joint.axis / np.linalg.norm(joint.axis)
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
row = self._make_row()
row[col_a_start + 3 : col_a_start + 6] = -axis
row[col_b_start + 3 : col_b_start + 6] = axis
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "perpendicular",
}
)
# --- Universal (Cardan) joint: 3 translation + 1 rotation ---
def _build_universal(self, joint: Joint) -> None:
"""Universal joint: ball + one rotation constraint.
Allows rotation about two axes, constrains rotation about the third.
"""
self._build_ball(joint)
axis = joint.axis / np.linalg.norm(joint.axis)
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
row = self._make_row()
row[col_a_start + 3 : col_a_start + 6] = -axis
row[col_b_start + 3 : col_b_start + 6] = axis
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "universal_rotation",
}
)
# --- Screw (helical) joint: 2 translation + 2 rotation + 1 coupled ---
def _build_screw(self, joint: Joint) -> None:
"""Screw joint: coupled rotation-translation along axis.
Like cylindrical but with a coupling constraint:
``v_axial - pitch * w_axial = 0``
"""
self._build_cylindrical(joint)
axis = joint.axis / np.linalg.norm(joint.axis)
col_a_start, _ = self._body_cols(joint.body_a)
col_b_start, _ = self._body_cols(joint.body_b)
row = self._make_row()
row[col_a_start : col_a_start + 3] = -axis
row[col_b_start : col_b_start + 3] = axis
row[col_a_start + 3 : col_a_start + 6] = joint.pitch * axis
row[col_b_start + 3 : col_b_start + 6] = -joint.pitch * axis
self.jacobian_rows.append(row)
self.row_labels.append(
{
"joint_id": joint.joint_id,
"type": "screw_coupling",
}
)
# --- Utilities ---
def _perpendicular_pair(self, axis: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
"""Generate two unit vectors perpendicular to *axis* and each other."""
if abs(axis[0]) < 0.9:
t1 = np.cross(axis, np.array([1.0, 0, 0]))
else:
t1 = np.cross(axis, np.array([0, 1.0, 0]))
t1 /= np.linalg.norm(t1)
t2 = np.cross(axis, t1)
t2 /= np.linalg.norm(t2)
return t1, t2
def get_jacobian(self) -> np.ndarray:
"""Return the full constraint Jacobian matrix."""
if not self.jacobian_rows:
return np.zeros((0, 6 * self.n_bodies))
return np.array(self.jacobian_rows)
def numerical_rank(self, tol: float = 1e-8) -> int:
"""Compute the numerical rank of the constraint Jacobian via SVD.
This is the number of truly independent scalar constraints,
accounting for geometric degeneracies that the combinatorial
pebble game cannot detect.
"""
j = self.get_jacobian()
if j.size == 0:
return 0
sv = np.linalg.svd(j, compute_uv=False)
return int(np.sum(sv > tol))
def find_dependencies(self, tol: float = 1e-8) -> list[int]:
"""Identify which constraint rows are numerically dependent.
Returns indices of rows that can be removed without changing
the Jacobian's rank.
"""
j = self.get_jacobian()
if j.size == 0:
return []
n_rows = j.shape[0]
dependent: list[int] = []
current = np.zeros((0, j.shape[1]))
current_rank = 0
for i in range(n_rows):
candidate = np.vstack([current, j[i : i + 1, :]]) if current.size else j[i : i + 1, :]
sv = np.linalg.svd(candidate, compute_uv=False)
new_rank = int(np.sum(sv > tol))
if new_rank > current_rank:
current = candidate
current_rank = new_rank
else:
dependent.append(i)
return dependent

394
solver/datagen/labeling.py Normal file
View File

@@ -0,0 +1,394 @@
"""Ground truth labeling pipeline for synthetic assemblies.
Produces rich per-constraint, per-joint, per-body, and assembly-level
labels by running both the pebble game and Jacobian verification and
correlating their results.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
if TYPE_CHECKING:
from typing import Any
__all__ = ["AssemblyLabels", "label_assembly"]
_GROUND_ID = -1
_SVD_TOL = 1e-8
# ---------------------------------------------------------------------------
# Label dataclasses
# ---------------------------------------------------------------------------
@dataclass
class ConstraintLabel:
"""Per scalar-constraint label combining both analysis methods."""
joint_id: int
constraint_idx: int
pebble_independent: bool
jacobian_independent: bool
@dataclass
class JointLabel:
"""Aggregated constraint counts for a single joint."""
joint_id: int
independent_count: int
redundant_count: int
total: int
@dataclass
class BodyDofLabel:
"""Per-body DOF signature from nullspace projection."""
body_id: int
translational_dof: int
rotational_dof: int
@dataclass
class AssemblyLabel:
"""Assembly-wide summary label."""
classification: str
total_dof: int
redundant_count: int
is_rigid: bool
is_minimally_rigid: bool
has_degeneracy: bool
@dataclass
class AssemblyLabels:
"""Complete ground truth labels for an assembly."""
per_constraint: list[ConstraintLabel]
per_joint: list[JointLabel]
per_body: list[BodyDofLabel]
assembly: AssemblyLabel
analysis: ConstraintAnalysis
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"per_constraint": [
{
"joint_id": c.joint_id,
"constraint_idx": c.constraint_idx,
"pebble_independent": c.pebble_independent,
"jacobian_independent": c.jacobian_independent,
}
for c in self.per_constraint
],
"per_joint": [
{
"joint_id": j.joint_id,
"independent_count": j.independent_count,
"redundant_count": j.redundant_count,
"total": j.total,
}
for j in self.per_joint
],
"per_body": [
{
"body_id": b.body_id,
"translational_dof": b.translational_dof,
"rotational_dof": b.rotational_dof,
}
for b in self.per_body
],
"assembly": {
"classification": self.assembly.classification,
"total_dof": self.assembly.total_dof,
"redundant_count": self.assembly.redundant_count,
"is_rigid": self.assembly.is_rigid,
"is_minimally_rigid": self.assembly.is_minimally_rigid,
"has_degeneracy": self.assembly.has_degeneracy,
},
}
# ---------------------------------------------------------------------------
# Per-body DOF from nullspace projection
# ---------------------------------------------------------------------------
def _compute_per_body_dof(
j_reduced: np.ndarray,
body_ids: list[int],
ground_body: int | None,
body_index: dict[int, int],
) -> list[BodyDofLabel]:
"""Compute translational and rotational DOF per body.
Uses SVD nullspace projection: for each body, extract its
translational (3 cols) and rotational (3 cols) components
from the nullspace basis and compute ranks.
"""
# Build column index mapping for the reduced Jacobian
# (ground body columns have been removed)
col_map: dict[int, int] = {}
col_idx = 0
for bid in body_ids:
if bid == ground_body:
continue
col_map[bid] = col_idx
col_idx += 1
results: list[BodyDofLabel] = []
if j_reduced.size == 0:
# No constraints — every body is fully free
for bid in body_ids:
if bid == ground_body:
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
else:
results.append(BodyDofLabel(body_id=bid, translational_dof=3, rotational_dof=3))
return results
# Full SVD to get nullspace
_u, s, vh = np.linalg.svd(j_reduced, full_matrices=True)
rank = int(np.sum(s > _SVD_TOL))
n_cols = j_reduced.shape[1]
if rank >= n_cols:
# Fully constrained — no nullspace
for bid in body_ids:
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
return results
# Nullspace basis: rows of Vh beyond the rank
nullspace = vh[rank:] # shape: (n_cols - rank, n_cols)
for bid in body_ids:
if bid == ground_body:
results.append(BodyDofLabel(body_id=bid, translational_dof=0, rotational_dof=0))
continue
idx = col_map[bid]
trans_cols = nullspace[:, idx * 6 : idx * 6 + 3]
rot_cols = nullspace[:, idx * 6 + 3 : idx * 6 + 6]
# Rank of each block = DOF in that category
if trans_cols.size > 0:
sv_t = np.linalg.svd(trans_cols, compute_uv=False)
t_dof = int(np.sum(sv_t > _SVD_TOL))
else:
t_dof = 0
if rot_cols.size > 0:
sv_r = np.linalg.svd(rot_cols, compute_uv=False)
r_dof = int(np.sum(sv_r > _SVD_TOL))
else:
r_dof = 0
results.append(BodyDofLabel(body_id=bid, translational_dof=t_dof, rotational_dof=r_dof))
return results
# ---------------------------------------------------------------------------
# Main labeling function
# ---------------------------------------------------------------------------
def label_assembly(
bodies: list[RigidBody],
joints: list[Joint],
ground_body: int | None = None,
) -> AssemblyLabels:
"""Produce complete ground truth labels for an assembly.
Runs both the pebble game and Jacobian verification internally,
then correlates their results into per-constraint, per-joint,
per-body, and assembly-level labels.
Args:
bodies: Rigid bodies in the assembly.
joints: Joints connecting the bodies.
ground_body: If set, this body is fixed to the world.
Returns:
AssemblyLabels with full label set and embedded ConstraintAnalysis.
"""
# ---- Pebble Game ----
pg = PebbleGame3D()
all_edge_results: list[dict[str, Any]] = []
if ground_body is not None:
pg.add_body(_GROUND_ID)
for body in bodies:
pg.add_body(body.body_id)
if ground_body is not None:
ground_joint = Joint(
joint_id=-1,
body_a=ground_body,
body_b=_GROUND_ID,
joint_type=JointType.FIXED,
anchor_a=bodies[0].position if bodies else np.zeros(3),
anchor_b=bodies[0].position if bodies else np.zeros(3),
)
pg.add_joint(ground_joint)
for joint in joints:
results = pg.add_joint(joint)
all_edge_results.extend(results)
grounded = ground_body is not None
combinatorial_independent = len(pg.state.independent_edges)
raw_dof = pg.get_dof()
ground_offset = 6 if grounded else 0
effective_dof = raw_dof - ground_offset
effective_internal_dof = max(0, effective_dof - (0 if grounded else 6))
redundant_count = pg.get_redundant_count()
if redundant_count > 0 and effective_internal_dof > 0:
classification = "mixed"
elif redundant_count > 0:
classification = "overconstrained"
elif effective_internal_dof > 0:
classification = "underconstrained"
else:
classification = "well-constrained"
# ---- Jacobian Verification ----
verifier = JacobianVerifier(bodies)
for joint in joints:
verifier.add_joint_constraints(joint)
j_full = verifier.get_jacobian()
j_reduced = j_full.copy()
if ground_body is not None and j_reduced.size > 0:
idx = verifier.body_index[ground_body]
cols_to_remove = list(range(idx * 6, (idx + 1) * 6))
j_reduced = np.delete(j_reduced, cols_to_remove, axis=1)
if j_reduced.size > 0:
sv = np.linalg.svd(j_reduced, compute_uv=False)
jacobian_rank = int(np.sum(sv > _SVD_TOL))
else:
jacobian_rank = 0
n_cols = j_reduced.shape[1] if j_reduced.size > 0 else 6 * len(bodies)
jacobian_nullity = n_cols - jacobian_rank
dependent_rows = verifier.find_dependencies()
dependent_set = set(dependent_rows)
trivial_dof = 0 if grounded else 6
jacobian_internal_dof = jacobian_nullity - trivial_dof
geometric_degeneracies = max(0, combinatorial_independent - jacobian_rank)
is_rigid = jacobian_nullity <= trivial_dof
is_minimally_rigid = is_rigid and len(dependent_rows) == 0
# ---- Per-constraint labels ----
# Map Jacobian rows to (joint_id, constraint_index).
# Rows are added contiguously per joint in the same order as joints.
row_to_joint: list[tuple[int, int]] = []
for joint in joints:
dof = joint.joint_type.dof
for ci in range(dof):
row_to_joint.append((joint.joint_id, ci))
per_constraint: list[ConstraintLabel] = []
for edge_idx, edge_result in enumerate(all_edge_results):
jid = edge_result["joint_id"]
ci = edge_result["constraint_index"]
pebble_indep = edge_result["independent"]
# Find matching Jacobian row
jacobian_indep = True
if edge_idx < len(row_to_joint):
row_idx = edge_idx
jacobian_indep = row_idx not in dependent_set
per_constraint.append(
ConstraintLabel(
joint_id=jid,
constraint_idx=ci,
pebble_independent=pebble_indep,
jacobian_independent=jacobian_indep,
)
)
# ---- Per-joint labels ----
joint_agg: dict[int, JointLabel] = {}
for cl in per_constraint:
if cl.joint_id not in joint_agg:
joint_agg[cl.joint_id] = JointLabel(
joint_id=cl.joint_id,
independent_count=0,
redundant_count=0,
total=0,
)
jl = joint_agg[cl.joint_id]
jl.total += 1
if cl.pebble_independent:
jl.independent_count += 1
else:
jl.redundant_count += 1
per_joint = [joint_agg[j.joint_id] for j in joints if j.joint_id in joint_agg]
# ---- Per-body DOF labels ----
body_ids = [b.body_id for b in bodies]
per_body = _compute_per_body_dof(
j_reduced,
body_ids,
ground_body,
verifier.body_index,
)
# ---- Assembly label ----
assembly_label = AssemblyLabel(
classification=classification,
total_dof=max(0, jacobian_internal_dof),
redundant_count=redundant_count,
is_rigid=is_rigid,
is_minimally_rigid=is_minimally_rigid,
has_degeneracy=geometric_degeneracies > 0,
)
# ---- ConstraintAnalysis (for backward compat) ----
analysis = ConstraintAnalysis(
combinatorial_dof=effective_dof,
combinatorial_internal_dof=effective_internal_dof,
combinatorial_redundant=redundant_count,
combinatorial_classification=classification,
per_edge_results=all_edge_results,
jacobian_rank=jacobian_rank,
jacobian_nullity=jacobian_nullity,
jacobian_internal_dof=max(0, jacobian_internal_dof),
numerically_dependent=dependent_rows,
geometric_degeneracies=geometric_degeneracies,
is_rigid=is_rigid,
is_minimally_rigid=is_minimally_rigid,
)
return AssemblyLabels(
per_constraint=per_constraint,
per_joint=per_joint,
per_body=per_body,
assembly=assembly_label,
analysis=analysis,
)

View File

@@ -0,0 +1,258 @@
"""(6,6)-Pebble game for 3D body-bar-hinge rigidity analysis.
Implements the pebble game algorithm adapted for CAD assembly constraint
graphs. Each rigid body has 6 DOF (3 translation + 3 rotation). Joints
between bodies remove DOF according to their type.
The pebble game provides a fast combinatorial *necessary* condition for
rigidity via Tay's theorem. It does not detect geometric degeneracies —
use :class:`solver.datagen.jacobian.JacobianVerifier` for the *sufficient*
condition.
References:
- Lee & Streinu, "Pebble Game Algorithms and Sparse Graphs", 2008
- Jacobs & Hendrickson, "An Algorithm for Two-Dimensional Rigidity
Percolation: The Pebble Game", J. Comput. Phys., 1997
- Tay, "Rigidity of Multigraphs I: Linking Rigid Bodies in n-space", 1984
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from solver.datagen.types import Joint, PebbleState
if TYPE_CHECKING:
from typing import Any
__all__ = ["PebbleGame3D"]
class PebbleGame3D:
"""Implements the (6,6)-pebble game for 3D body-bar-hinge frameworks.
For body-bar-hinge structures in 3D, Tay's theorem states that a
multigraph G on n vertices is generically minimally rigid iff:
|E| = 6n - 6 and |E'| <= 6n' - 6 for all subgraphs (n' >= 2)
The (6,6)-pebble game tests this sparsity condition incrementally.
Each vertex starts with 6 pebbles (representing 6 DOF). To insert
an edge, we need to collect 6+1=7 pebbles on its two endpoints.
If we can, the edge is independent (removes a DOF). If not, it's
redundant (overconstrained).
In the CAD assembly context:
- Vertices = rigid bodies
- Edges = scalar constraints from joints
- A revolute joint (5 DOF removed) maps to 5 multigraph edges
- A fixed joint (6 DOF removed) maps to 6 multigraph edges
"""
K = 6 # Pebbles per vertex (DOF per rigid body in 3D)
L = 6 # Sparsity parameter: need K+1=7 pebbles to accept edge
def __init__(self) -> None:
self.state = PebbleState()
self._edge_counter = 0
self._bodies: set[int] = set()
def add_body(self, body_id: int) -> None:
"""Register a rigid body (vertex) with K=6 free pebbles."""
if body_id in self._bodies:
return
self._bodies.add(body_id)
self.state.free_pebbles[body_id] = self.K
self.state.incoming[body_id] = set()
self.state.outgoing[body_id] = set()
def add_joint(self, joint: Joint) -> list[dict[str, Any]]:
"""Expand a joint into multigraph edges and test each for independence.
A joint that removes ``d`` DOF becomes ``d`` edges in the multigraph.
Each edge is tested individually via the pebble game.
Returns a list of dicts, one per scalar constraint, with:
- edge_id: int
- independent: bool
- dof_remaining: int (total free pebbles after this edge)
"""
self.add_body(joint.body_a)
self.add_body(joint.body_b)
num_constraints = joint.joint_type.dof
results: list[dict[str, Any]] = []
for i in range(num_constraints):
edge_id = self._edge_counter
self._edge_counter += 1
independent = self._try_insert_edge(edge_id, joint.body_a, joint.body_b)
total_free = sum(self.state.free_pebbles.values())
results.append(
{
"edge_id": edge_id,
"joint_id": joint.joint_id,
"constraint_index": i,
"independent": independent,
"dof_remaining": total_free,
}
)
return results
def _try_insert_edge(self, edge_id: int, u: int, v: int) -> bool:
"""Try to insert a directed edge between u and v.
The edge is accepted (independent) iff we can collect L+1 = 7
pebbles on the two endpoints {u, v} combined.
If accepted, one pebble is consumed and the edge is directed
away from the vertex that gives up the pebble.
"""
# Count current free pebbles on u and v
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
# Try to gather enough pebbles via DFS reachability search
if available < self.L + 1:
needed = (self.L + 1) - available
# Try to free pebbles by searching from u first, then v
for target in (u, v):
while needed > 0:
found = self._search_and_collect(target, frozenset({u, v}))
if not found:
break
needed -= 1
# Recheck after collection attempts
available = self.state.free_pebbles[u] + self.state.free_pebbles[v]
if available >= self.L + 1:
# Accept: consume a pebble from whichever endpoint has one
source = u if self.state.free_pebbles[u] > 0 else v
self.state.free_pebbles[source] -= 1
self.state.directed_edges[edge_id] = (source, v if source == u else u)
self.state.outgoing[source].add((edge_id, v if source == u else u))
target = v if source == u else u
self.state.incoming[target].add((edge_id, source))
self.state.independent_edges.add(edge_id)
return True
else:
# Reject: edge is redundant (overconstrained)
self.state.redundant_edges.add(edge_id)
return False
def _search_and_collect(self, target: int, forbidden: frozenset[int]) -> bool:
"""DFS to find a free pebble reachable from *target* and move it.
Follows directed edges *backwards* (from destination to source)
to find a vertex with a free pebble that isn't in *forbidden*.
When found, reverses the path to move the pebble to *target*.
Returns True if a pebble was successfully moved to target.
"""
# BFS/DFS through the directed graph following outgoing edges
# from target. An outgoing edge (target -> w) means target spent
# a pebble on that edge. If we can find a vertex with a free
# pebble, we reverse edges along the path to move it.
visited: set[int] = set()
# Stack: (current_vertex, path_of_edge_ids_to_reverse)
stack: list[tuple[int, list[int]]] = [(target, [])]
while stack:
current, path = stack.pop()
if current in visited:
continue
visited.add(current)
# Check if current vertex (not in forbidden, not target)
# has a free pebble
if (
current != target
and current not in forbidden
and self.state.free_pebbles[current] > 0
):
# Found a pebble — reverse the path
self._reverse_path(path, current)
return True
# Follow outgoing edges from current vertex
for eid, neighbor in self.state.outgoing.get(current, set()):
if neighbor not in visited:
stack.append((neighbor, [*path, eid]))
return False
def _reverse_path(self, edge_ids: list[int], pebble_source: int) -> None:
"""Reverse directed edges along a path, moving a pebble to the start.
The pebble at *pebble_source* is consumed by the last edge in
the path, and a pebble is freed at the path's start vertex.
"""
if not edge_ids:
return
# Reverse each edge in the path
for eid in edge_ids:
old_source, old_target = self.state.directed_edges[eid]
# Remove from adjacency
self.state.outgoing[old_source].discard((eid, old_target))
self.state.incoming[old_target].discard((eid, old_source))
# Reverse direction
self.state.directed_edges[eid] = (old_target, old_source)
self.state.outgoing[old_target].add((eid, old_source))
self.state.incoming[old_source].add((eid, old_target))
# Move pebble counts: source loses one, first vertex in path gains one
self.state.free_pebbles[pebble_source] -= 1
# After all reversals, the vertex at the beginning of the
# search path gains a pebble
_first_src, first_tgt = self.state.directed_edges[edge_ids[0]]
self.state.free_pebbles[first_tgt] += 1
def get_dof(self) -> int:
"""Total remaining DOF = sum of free pebbles.
For a fully rigid assembly, this should be 6 (the trivial rigid
body motions of the whole assembly). Internal DOF = total - 6.
"""
return sum(self.state.free_pebbles.values())
def get_internal_dof(self) -> int:
"""Internal (non-trivial) degrees of freedom."""
return max(0, self.get_dof() - 6)
def is_rigid(self) -> bool:
"""Combinatorial rigidity check: rigid iff at most 6 pebbles remain."""
return self.get_dof() <= self.L
def get_redundant_count(self) -> int:
"""Number of redundant (overconstrained) scalar constraints."""
return len(self.state.redundant_edges)
def classify_assembly(self, *, grounded: bool = False) -> str:
"""Classify the assembly state.
Args:
grounded: If True, the baseline trivial DOF is 0 (not 6),
because the ground body's 6 DOF were removed.
"""
total_dof = self.get_dof()
redundant = self.get_redundant_count()
baseline = 0 if grounded else self.L
if redundant > 0 and total_dof > baseline:
return "mixed" # Both under and over-constrained regions
elif redundant > 0:
return "overconstrained"
elif total_dof > baseline:
return "underconstrained"
elif total_dof == baseline:
return "well-constrained"
else:
return "overconstrained"

144
solver/datagen/types.py Normal file
View File

@@ -0,0 +1,144 @@
"""Shared data types for assembly constraint analysis.
Types ported from the pebble-game synthetic data generator for reuse
across the solver package (data generation, training, inference).
"""
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"ConstraintAnalysis",
"Joint",
"JointType",
"PebbleState",
"RigidBody",
]
# ---------------------------------------------------------------------------
# Joint definitions: each joint type removes a known number of DOF
# ---------------------------------------------------------------------------
class JointType(enum.Enum):
"""Standard CAD joint types with their DOF-removal counts.
Each joint between two 6-DOF rigid bodies removes a specific number
of relative degrees of freedom. In the body-bar-hinge multigraph
representation, each joint maps to a number of edges equal to the
DOF it removes.
Values are ``(ordinal, dof_removed)`` tuples so that joint types
sharing the same DOF count remain distinct enum members. Use the
:attr:`dof` property to get the scalar constraint count.
"""
FIXED = (0, 6) # Locks all relative motion
REVOLUTE = (1, 5) # Allows rotation about one axis
CYLINDRICAL = (2, 4) # Allows rotation + translation along one axis
SLIDER = (3, 5) # Allows translation along one axis (prismatic)
BALL = (4, 3) # Allows rotation about a point (spherical)
PLANAR = (5, 3) # Allows 2D translation + rotation normal to plane
SCREW = (6, 5) # Coupled rotation-translation (helical)
UNIVERSAL = (7, 4) # Two rotational DOF (Cardan/U-joint)
PARALLEL = (8, 3) # Forces parallel orientation (3 rotation constraints)
PERPENDICULAR = (9, 1) # Single angular constraint
DISTANCE = (10, 1) # Single scalar distance constraint
@property
def dof(self) -> int:
"""Number of scalar constraints (DOF removed) by this joint type."""
return self.value[1]
# ---------------------------------------------------------------------------
# Data structures
# ---------------------------------------------------------------------------
@dataclass
class RigidBody:
"""A rigid body in the assembly with pose and geometry info."""
body_id: int
position: np.ndarray = field(default_factory=lambda: np.zeros(3))
orientation: np.ndarray = field(default_factory=lambda: np.eye(3))
# Anchor points for joints, in local frame
# Populated when joints reference specific geometry
local_anchors: dict[str, np.ndarray] = field(default_factory=dict)
@dataclass
class Joint:
"""A joint connecting two rigid bodies."""
joint_id: int
body_a: int # Index of first body
body_b: int # Index of second body
joint_type: JointType
# Joint parameters in world frame
anchor_a: np.ndarray = field(default_factory=lambda: np.zeros(3))
anchor_b: np.ndarray = field(default_factory=lambda: np.zeros(3))
axis: np.ndarray = field(
default_factory=lambda: np.array([0.0, 0.0, 1.0]),
)
# For screw joints
pitch: float = 0.0
@dataclass
class PebbleState:
"""Tracks the state of the pebble game on the multigraph."""
# Number of free pebbles per body (vertex). Starts at 6.
free_pebbles: dict[int, int] = field(default_factory=dict)
# Directed edges: edge_id -> (source_body, target_body)
# Edge is directed away from the body that "spent" a pebble.
directed_edges: dict[int, tuple[int, int]] = field(default_factory=dict)
# Track which edges are independent vs redundant
independent_edges: set[int] = field(default_factory=set)
redundant_edges: set[int] = field(default_factory=set)
# Adjacency: body_id -> set of (edge_id, neighbor_body_id)
# Following directed edges *towards* a body (incoming edges)
incoming: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
# Outgoing edges from a body
outgoing: dict[int, set[tuple[int, int]]] = field(default_factory=dict)
@dataclass
class ConstraintAnalysis:
"""Results of analyzing an assembly's constraint system."""
# Pebble game (combinatorial) results
combinatorial_dof: int
combinatorial_internal_dof: int
combinatorial_redundant: int
combinatorial_classification: str
per_edge_results: list[dict[str, Any]]
# Numerical (Jacobian) results
jacobian_rank: int
jacobian_nullity: int # = 6n - rank = total DOF
jacobian_internal_dof: int # = nullity - 6
numerically_dependent: list[int]
# Combined
geometric_degeneracies: int # = combinatorial_independent - jacobian_rank
is_rigid: bool
is_minimally_rigid: bool

View File

View File

View File

47
solver/mates/__init__.py Normal file
View File

@@ -0,0 +1,47 @@
"""Mate-level constraint types for assembly analysis."""
from solver.mates.conversion import (
MateAnalysisResult,
analyze_mate_assembly,
convert_mates_to_joints,
)
from solver.mates.generator import (
SyntheticMateGenerator,
generate_mate_training_batch,
)
from solver.mates.labeling import (
MateAssemblyLabels,
MateLabel,
label_mate_assembly,
)
from solver.mates.patterns import (
JointPattern,
PatternMatch,
recognize_patterns,
)
from solver.mates.primitives import (
GeometryRef,
GeometryType,
Mate,
MateType,
dof_removed,
)
__all__ = [
"GeometryRef",
"GeometryType",
"JointPattern",
"Mate",
"MateAnalysisResult",
"MateAssemblyLabels",
"MateLabel",
"MateType",
"PatternMatch",
"SyntheticMateGenerator",
"analyze_mate_assembly",
"convert_mates_to_joints",
"dof_removed",
"generate_mate_training_batch",
"label_mate_assembly",
"recognize_patterns",
]

276
solver/mates/conversion.py Normal file
View File

@@ -0,0 +1,276 @@
"""Mate-to-joint conversion and assembly analysis.
Bridges the mate-level constraint representation to the existing
joint-based analysis pipeline. Converts recognized mate patterns
to Joint objects, then runs the pebble game and Jacobian analysis,
maintaining bidirectional traceability between mates and joints.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.labeling import AssemblyLabels, label_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
from solver.mates.patterns import PatternMatch, recognize_patterns
if TYPE_CHECKING:
from typing import Any
from solver.mates.primitives import Mate
__all__ = [
"MateAnalysisResult",
"analyze_mate_assembly",
"convert_mates_to_joints",
]
# ---------------------------------------------------------------------------
# Result dataclass
# ---------------------------------------------------------------------------
@dataclass
class MateAnalysisResult:
"""Combined result of mate-based assembly analysis.
Attributes:
patterns: Recognized joint patterns from mate grouping.
joints: Joint objects produced by conversion.
mate_to_joint: Mapping from mate_id to list of joint_ids.
joint_to_mates: Mapping from joint_id to list of mate_ids.
analysis: Constraint analysis from pebble game + Jacobian.
labels: Full ground truth labels from label_assembly.
"""
patterns: list[PatternMatch]
joints: list[Joint]
mate_to_joint: dict[int, list[int]] = field(default_factory=dict)
joint_to_mates: dict[int, list[int]] = field(default_factory=dict)
analysis: ConstraintAnalysis | None = None
labels: AssemblyLabels | None = None
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"patterns": [p.to_dict() for p in self.patterns],
"joints": [
{
"joint_id": j.joint_id,
"body_a": j.body_a,
"body_b": j.body_b,
"joint_type": j.joint_type.name,
}
for j in self.joints
],
"mate_to_joint": self.mate_to_joint,
"joint_to_mates": self.joint_to_mates,
"labels": self.labels.to_dict() if self.labels else None,
}
# ---------------------------------------------------------------------------
# Pattern-to-JointType mapping
# ---------------------------------------------------------------------------
# Maps (JointPattern value) to JointType for known patterns.
# Used by convert_mates_to_joints when a full pattern is recognized.
_PATTERN_JOINT_MAP: dict[str, JointType] = {
"hinge": JointType.REVOLUTE,
"slider": JointType.SLIDER,
"cylinder": JointType.CYLINDRICAL,
"ball": JointType.BALL,
"planar": JointType.PLANAR,
"fixed": JointType.FIXED,
}
# Fallback mapping for individual mate types when no pattern is recognized.
_MATE_JOINT_FALLBACK: dict[str, JointType] = {
"COINCIDENT": JointType.PLANAR,
"CONCENTRIC": JointType.CYLINDRICAL,
"PARALLEL": JointType.PARALLEL,
"PERPENDICULAR": JointType.PERPENDICULAR,
"TANGENT": JointType.DISTANCE,
"DISTANCE": JointType.DISTANCE,
"ANGLE": JointType.PERPENDICULAR,
"LOCK": JointType.FIXED,
}
# ---------------------------------------------------------------------------
# Conversion
# ---------------------------------------------------------------------------
def _compute_joint_params(
pattern: PatternMatch,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Extract anchor and axis from pattern mates.
Returns:
(anchor_a, anchor_b, axis)
"""
anchor_a = np.zeros(3)
anchor_b = np.zeros(3)
axis = np.array([0.0, 0.0, 1.0])
for mate in pattern.mates:
ref_a = mate.ref_a
ref_b = mate.ref_b
anchor_a = ref_a.origin.copy()
anchor_b = ref_b.origin.copy()
if ref_a.direction is not None:
axis = ref_a.direction.copy()
break
return anchor_a, anchor_b, axis
def _convert_single_mate(
mate: Mate,
joint_id: int,
) -> Joint:
"""Convert a single unmatched mate to a Joint."""
joint_type = _MATE_JOINT_FALLBACK.get(mate.mate_type.name, JointType.DISTANCE)
anchor_a = mate.ref_a.origin.copy()
anchor_b = mate.ref_b.origin.copy()
axis = np.array([0.0, 0.0, 1.0])
if mate.ref_a.direction is not None:
axis = mate.ref_a.direction.copy()
return Joint(
joint_id=joint_id,
body_a=mate.ref_a.body_id,
body_b=mate.ref_b.body_id,
joint_type=joint_type,
anchor_a=anchor_a,
anchor_b=anchor_b,
axis=axis,
)
def convert_mates_to_joints(
mates: list[Mate],
bodies: list[RigidBody] | None = None,
) -> tuple[list[Joint], dict[int, list[int]], dict[int, list[int]]]:
"""Convert mates to Joint objects via pattern recognition.
For each body pair:
- If mates form a recognized pattern, emit the equivalent joint.
- Otherwise, emit individual joints for each unmatched mate.
Args:
mates: Mate constraints to convert.
bodies: Optional body list (unused currently, reserved for
future geometry lookups).
Returns:
(joints, mate_to_joint, joint_to_mates) tuple.
"""
if not mates:
return [], {}, {}
patterns = recognize_patterns(mates)
joints: list[Joint] = []
mate_to_joint: dict[int, list[int]] = {}
joint_to_mates: dict[int, list[int]] = {}
# Track which mates have been consumed by full-confidence patterns
consumed_mate_ids: set[int] = set()
next_joint_id = 0
# First pass: emit joints for full-confidence patterns
for pattern in patterns:
if pattern.confidence < 1.0:
continue
if pattern.pattern.value not in _PATTERN_JOINT_MAP:
continue
# Check if any of these mates were already consumed
mate_ids = [m.mate_id for m in pattern.mates]
if any(mid in consumed_mate_ids for mid in mate_ids):
continue
joint_type = _PATTERN_JOINT_MAP[pattern.pattern.value]
anchor_a, anchor_b, axis = _compute_joint_params(pattern)
joint = Joint(
joint_id=next_joint_id,
body_a=pattern.body_a,
body_b=pattern.body_b,
joint_type=joint_type,
anchor_a=anchor_a,
anchor_b=anchor_b,
axis=axis,
)
joints.append(joint)
joint_to_mates[next_joint_id] = mate_ids
for mid in mate_ids:
mate_to_joint.setdefault(mid, []).append(next_joint_id)
consumed_mate_ids.add(mid)
next_joint_id += 1
# Second pass: emit individual joints for unconsumed mates
for mate in mates:
if mate.mate_id in consumed_mate_ids:
continue
joint = _convert_single_mate(mate, next_joint_id)
joints.append(joint)
joint_to_mates[next_joint_id] = [mate.mate_id]
mate_to_joint.setdefault(mate.mate_id, []).append(next_joint_id)
next_joint_id += 1
return joints, mate_to_joint, joint_to_mates
# ---------------------------------------------------------------------------
# Full analysis pipeline
# ---------------------------------------------------------------------------
def analyze_mate_assembly(
bodies: list[RigidBody],
mates: list[Mate],
ground_body: int | None = None,
) -> MateAnalysisResult:
"""Run the full analysis pipeline on a mate-based assembly.
Orchestrates: recognize_patterns -> convert_mates_to_joints ->
label_assembly, returning a combined result with full traceability.
Args:
bodies: Rigid bodies in the assembly.
mates: Mate constraints between the bodies.
ground_body: If set, this body is fixed to the world.
Returns:
MateAnalysisResult with patterns, joints, mappings, and labels.
"""
patterns = recognize_patterns(mates)
joints, mate_to_joint, joint_to_mates = convert_mates_to_joints(mates, bodies)
labels = label_assembly(bodies, joints, ground_body)
return MateAnalysisResult(
patterns=patterns,
joints=joints,
mate_to_joint=mate_to_joint,
joint_to_mates=joint_to_mates,
analysis=labels.analysis,
labels=labels,
)

315
solver/mates/generator.py Normal file
View File

@@ -0,0 +1,315 @@
"""Mate-based synthetic assembly generator.
Wraps SyntheticAssemblyGenerator to produce mate-level training data.
Generates joint-based assemblies via the existing generator, then
reverse-maps joints to plausible mate combinations. Supports noise
injection (redundant, missing, incompatible mates) for robust training.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import numpy as np
from solver.datagen.generator import SyntheticAssemblyGenerator
from solver.datagen.types import Joint, JointType, RigidBody
from solver.mates.conversion import MateAnalysisResult, analyze_mate_assembly
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
if TYPE_CHECKING:
from typing import Any
__all__ = [
"SyntheticMateGenerator",
"generate_mate_training_batch",
]
# ---------------------------------------------------------------------------
# Reverse mapping: JointType -> list of (MateType, geom_a, geom_b) combos
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _MateSpec:
"""Specification for a mate to generate from a joint."""
mate_type: MateType
geom_a: GeometryType
geom_b: GeometryType
_JOINT_TO_MATES: dict[JointType, list[_MateSpec]] = {
JointType.REVOLUTE: [
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
],
JointType.CYLINDRICAL: [
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
],
JointType.BALL: [
_MateSpec(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),
],
JointType.FIXED: [
_MateSpec(MateType.LOCK, GeometryType.FACE, GeometryType.FACE),
],
JointType.SLIDER: [
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
_MateSpec(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
],
JointType.PLANAR: [
_MateSpec(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),
],
}
# ---------------------------------------------------------------------------
# Generator
# ---------------------------------------------------------------------------
class SyntheticMateGenerator:
"""Generates mate-based assemblies for training data.
Wraps SyntheticAssemblyGenerator to produce joint-based assemblies,
then reverse-maps each joint to a plausible set of mate constraints.
Args:
seed: Random seed for reproducibility.
redundant_prob: Probability of injecting a redundant mate per joint.
missing_prob: Probability of dropping a mate from a multi-mate pattern.
incompatible_prob: Probability of injecting a mate with wrong geometry.
"""
def __init__(
self,
seed: int = 42,
*,
redundant_prob: float = 0.0,
missing_prob: float = 0.0,
incompatible_prob: float = 0.0,
) -> None:
self._joint_gen = SyntheticAssemblyGenerator(seed=seed)
self._rng = np.random.default_rng(seed)
self.redundant_prob = redundant_prob
self.missing_prob = missing_prob
self.incompatible_prob = incompatible_prob
def _make_geometry_ref(
self,
body_id: int,
geom_type: GeometryType,
joint: Joint,
*,
is_ref_a: bool = True,
) -> GeometryRef:
"""Create a GeometryRef from joint geometry.
Uses joint anchor, axis, and body_id to produce a ref
with realistic geometry for the given type.
"""
origin = joint.anchor_a if is_ref_a else joint.anchor_b
direction: np.ndarray | None = None
if geom_type in {GeometryType.AXIS, GeometryType.PLANE, GeometryType.FACE}:
direction = joint.axis.copy()
geom_id = f"{geom_type.value.capitalize()}001"
return GeometryRef(
body_id=body_id,
geometry_type=geom_type,
geometry_id=geom_id,
origin=origin.copy(),
direction=direction,
)
def _reverse_map_joint(
self,
joint: Joint,
next_mate_id: int,
) -> list[Mate]:
"""Convert a joint to its mate representation."""
specs = _JOINT_TO_MATES.get(joint.joint_type, [])
if not specs:
# Fallback: emit a single DISTANCE mate
specs = [_MateSpec(MateType.DISTANCE, GeometryType.POINT, GeometryType.POINT)]
mates: list[Mate] = []
for spec in specs:
ref_a = self._make_geometry_ref(joint.body_a, spec.geom_a, joint, is_ref_a=True)
ref_b = self._make_geometry_ref(joint.body_b, spec.geom_b, joint, is_ref_a=False)
mates.append(
Mate(
mate_id=next_mate_id + len(mates),
mate_type=spec.mate_type,
ref_a=ref_a,
ref_b=ref_b,
)
)
return mates
def _inject_noise(
self,
mates: list[Mate],
next_mate_id: int,
) -> list[Mate]:
"""Apply noise injection to the mate list.
Modifies the list in-place and may add new mates.
Returns the (possibly extended) list.
"""
result = list(mates)
extra: list[Mate] = []
for mate in mates:
# Redundant: duplicate a mate
if self._rng.random() < self.redundant_prob:
dup = Mate(
mate_id=next_mate_id + len(extra),
mate_type=mate.mate_type,
ref_a=mate.ref_a,
ref_b=mate.ref_b,
value=mate.value,
tolerance=mate.tolerance,
)
extra.append(dup)
# Incompatible: wrong geometry type
if self._rng.random() < self.incompatible_prob:
bad_geom = GeometryType.POINT
bad_ref = GeometryRef(
body_id=mate.ref_a.body_id,
geometry_type=bad_geom,
geometry_id="BadGeom001",
origin=mate.ref_a.origin.copy(),
direction=None,
)
extra.append(
Mate(
mate_id=next_mate_id + len(extra),
mate_type=MateType.CONCENTRIC,
ref_a=bad_ref,
ref_b=mate.ref_b,
)
)
result.extend(extra)
# Missing: drop mates from multi-mate patterns (only if > 1 mate
# for same body pair)
if self.missing_prob > 0:
filtered: list[Mate] = []
for mate in result:
if self._rng.random() < self.missing_prob:
continue
filtered.append(mate)
# Ensure at least one mate remains
if not filtered and result:
filtered = [result[0]]
result = filtered
return result
def generate(
self,
n_bodies: int = 4,
*,
grounded: bool = False,
) -> tuple[list[RigidBody], list[Mate], MateAnalysisResult]:
"""Generate a mate-based assembly.
Args:
n_bodies: Number of rigid bodies.
grounded: Whether to ground the first body.
Returns:
(bodies, mates, analysis_result) tuple.
"""
bodies, joints, _analysis = self._joint_gen.generate_chain_assembly(
n_bodies,
joint_type=JointType.REVOLUTE,
grounded=grounded,
)
mates: list[Mate] = []
next_id = 0
for joint in joints:
joint_mates = self._reverse_map_joint(joint, next_id)
mates.extend(joint_mates)
next_id += len(joint_mates)
# Apply noise
mates = self._inject_noise(mates, next_id)
ground_body = bodies[0].body_id if grounded else None
result = analyze_mate_assembly(bodies, mates, ground_body)
return bodies, mates, result
# ---------------------------------------------------------------------------
# Batch generation
# ---------------------------------------------------------------------------
def generate_mate_training_batch(
batch_size: int = 100,
n_bodies_range: tuple[int, int] = (3, 8),
seed: int = 42,
*,
redundant_prob: float = 0.0,
missing_prob: float = 0.0,
incompatible_prob: float = 0.0,
grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
"""Produce a batch of mate-level training examples.
Args:
batch_size: Number of assemblies to generate.
n_bodies_range: (min, max_exclusive) body count.
seed: Random seed.
redundant_prob: Probability of redundant mate injection.
missing_prob: Probability of missing mate injection.
incompatible_prob: Probability of incompatible mate injection.
grounded_ratio: Fraction of assemblies that are grounded.
Returns:
List of dicts with bodies, mates, patterns, and labels.
"""
rng = np.random.default_rng(seed)
examples: list[dict[str, Any]] = []
for i in range(batch_size):
gen = SyntheticMateGenerator(
seed=seed + i,
redundant_prob=redundant_prob,
missing_prob=missing_prob,
incompatible_prob=incompatible_prob,
)
n = int(rng.integers(*n_bodies_range))
grounded = bool(rng.random() < grounded_ratio)
bodies, mates, result = gen.generate(n, grounded=grounded)
examples.append(
{
"bodies": [
{
"body_id": b.body_id,
"position": b.position.tolist(),
}
for b in bodies
],
"mates": [m.to_dict() for m in mates],
"patterns": [p.to_dict() for p in result.patterns],
"labels": result.labels.to_dict() if result.labels else None,
"n_bodies": len(bodies),
"n_mates": len(mates),
"n_joints": len(result.joints),
}
)
return examples

224
solver/mates/labeling.py Normal file
View File

@@ -0,0 +1,224 @@
"""Mate-level ground truth labels for assembly analysis.
Back-attributes joint-level independence results to originating mates
via the mate-to-joint mapping from conversion.py. Produces per-mate
labels indicating whether each mate is independent, redundant, or
degenerate.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
from solver.mates.conversion import analyze_mate_assembly
if TYPE_CHECKING:
from typing import Any
from solver.datagen.labeling import AssemblyLabel
from solver.datagen.types import ConstraintAnalysis, RigidBody
from solver.mates.conversion import MateAnalysisResult
from solver.mates.patterns import JointPattern, PatternMatch
from solver.mates.primitives import Mate
__all__ = [
"MateAssemblyLabels",
"MateLabel",
"label_mate_assembly",
]
# ---------------------------------------------------------------------------
# Label dataclasses
# ---------------------------------------------------------------------------
@dataclass
class MateLabel:
"""Per-mate ground truth label.
Attributes:
mate_id: The mate this label refers to.
is_independent: Contributes non-redundant DOF removal.
is_redundant: Fully redundant (removable without DOF change).
is_degenerate: Combinatorially independent but geometrically dependent.
pattern: Which joint pattern this mate belongs to, if any.
issue: Detected issue type, if any.
"""
mate_id: int
is_independent: bool = True
is_redundant: bool = False
is_degenerate: bool = False
pattern: JointPattern | None = None
issue: str | None = None
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"mate_id": self.mate_id,
"is_independent": self.is_independent,
"is_redundant": self.is_redundant,
"is_degenerate": self.is_degenerate,
"pattern": self.pattern.value if self.pattern else None,
"issue": self.issue,
}
@dataclass
class MateAssemblyLabels:
"""Complete mate-level ground truth labels for an assembly.
Attributes:
per_mate: Per-mate labels.
patterns: Recognized joint patterns.
assembly: Assembly-wide summary label.
analysis: Constraint analysis from pebble game + Jacobian.
"""
per_mate: list[MateLabel]
patterns: list[PatternMatch]
assembly: AssemblyLabel
analysis: ConstraintAnalysis
mate_analysis: MateAnalysisResult | None = None
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"per_mate": [ml.to_dict() for ml in self.per_mate],
"patterns": [p.to_dict() for p in self.patterns],
"assembly": {
"classification": self.assembly.classification,
"total_dof": self.assembly.total_dof,
"redundant_count": self.assembly.redundant_count,
"is_rigid": self.assembly.is_rigid,
"is_minimally_rigid": self.assembly.is_minimally_rigid,
"has_degeneracy": self.assembly.has_degeneracy,
},
}
# ---------------------------------------------------------------------------
# Labeling logic
# ---------------------------------------------------------------------------
def _build_mate_pattern_map(
patterns: list[PatternMatch],
) -> dict[int, JointPattern]:
"""Map mate_ids to the pattern they belong to (best match)."""
result: dict[int, JointPattern] = {}
# Sort by confidence descending so best matches win
sorted_patterns = sorted(patterns, key=lambda p: -p.confidence)
for pm in sorted_patterns:
if pm.confidence < 1.0:
continue
for mate in pm.mates:
if mate.mate_id not in result:
result[mate.mate_id] = pm.pattern
return result
def label_mate_assembly(
bodies: list[RigidBody],
mates: list[Mate],
ground_body: int | None = None,
) -> MateAssemblyLabels:
"""Produce mate-level ground truth labels for an assembly.
Runs analyze_mate_assembly() internally, then back-attributes
joint-level independence to originating mates via the mate_to_joint
mapping.
A mate is:
- **redundant** if ALL joints it contributes to are fully redundant
- **degenerate** if any joint it contributes to is geometrically
dependent but combinatorially independent
- **independent** otherwise
Args:
bodies: Rigid bodies in the assembly.
mates: Mate constraints between the bodies.
ground_body: If set, this body is fixed to the world.
Returns:
MateAssemblyLabels with per-mate labels and assembly summary.
"""
mate_result = analyze_mate_assembly(bodies, mates, ground_body)
# Build per-joint redundancy from labels
joint_redundant: dict[int, bool] = {}
joint_degenerate: dict[int, bool] = {}
if mate_result.labels is not None:
for jl in mate_result.labels.per_joint:
# A joint is fully redundant if all its constraints are redundant
joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0
# Joint is degenerate if it has more independent constraints
# than Jacobian rank would suggest (geometric degeneracy)
joint_degenerate[jl.joint_id] = False
# Check for geometric degeneracy via per-constraint labels
for cl in mate_result.labels.per_constraint:
if cl.pebble_independent and not cl.jacobian_independent:
joint_degenerate[cl.joint_id] = True
# Build pattern membership map
pattern_map = _build_mate_pattern_map(mate_result.patterns)
# Back-attribute to mates
per_mate: list[MateLabel] = []
for mate in mates:
mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, [])
if not mate_joint_ids:
# Mate wasn't converted to any joint (shouldn't happen, but safe)
per_mate.append(
MateLabel(
mate_id=mate.mate_id,
is_independent=False,
is_redundant=True,
issue="unmapped",
)
)
continue
# Redundant if ALL contributed joints are redundant
all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids)
# Degenerate if ANY contributed joint is degenerate
any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids)
is_independent = not all_redundant
pattern = pattern_map.get(mate.mate_id)
# Determine issue string
issue: str | None = None
if all_redundant:
issue = "redundant"
elif any_degenerate:
issue = "degenerate"
per_mate.append(
MateLabel(
mate_id=mate.mate_id,
is_independent=is_independent,
is_redundant=all_redundant,
is_degenerate=any_degenerate,
pattern=pattern,
issue=issue,
)
)
# Assembly label
assert mate_result.labels is not None
assembly_label = mate_result.labels.assembly
return MateAssemblyLabels(
per_mate=per_mate,
patterns=mate_result.patterns,
assembly=assembly_label,
analysis=mate_result.labels.analysis,
mate_analysis=mate_result,
)

284
solver/mates/patterns.py Normal file
View File

@@ -0,0 +1,284 @@
"""Joint pattern recognition from mate combinations.
Groups mates by body pair and matches them against canonical joint
patterns (hinge, slider, ball, etc.). Each pattern is a known
combination of mate types that together constrain motion equivalently
to a single mechanical joint.
"""
from __future__ import annotations
import enum
from collections import defaultdict
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
from solver.datagen.types import JointType
from solver.mates.primitives import GeometryType, Mate, MateType
if TYPE_CHECKING:
from typing import Any
__all__ = [
"JointPattern",
"PatternMatch",
"recognize_patterns",
]
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class JointPattern(enum.Enum):
"""Canonical joint patterns formed by mate combinations."""
HINGE = "hinge"
SLIDER = "slider"
CYLINDER = "cylinder"
BALL = "ball"
PLANAR = "planar"
FIXED = "fixed"
GEAR = "gear"
RACK_PINION = "rack_pinion"
UNKNOWN = "unknown"
# ---------------------------------------------------------------------------
# Pattern match result
# ---------------------------------------------------------------------------
@dataclass
class PatternMatch:
"""Result of matching a group of mates to a joint pattern.
Attributes:
pattern: The identified joint pattern.
mates: The mates that form this pattern.
body_a: First body in the pair.
body_b: Second body in the pair.
confidence: How well the mates match the canonical pattern (0-1).
equivalent_joint_type: The JointType this pattern maps to.
missing_mates: Descriptions of mates absent for a full match.
"""
pattern: JointPattern
mates: list[Mate]
body_a: int
body_b: int
confidence: float
equivalent_joint_type: JointType
missing_mates: list[str] = field(default_factory=list)
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"pattern": self.pattern.value,
"body_a": self.body_a,
"body_b": self.body_b,
"confidence": self.confidence,
"equivalent_joint_type": self.equivalent_joint_type.name,
"mate_ids": [m.mate_id for m in self.mates],
"missing_mates": self.missing_mates,
}
# ---------------------------------------------------------------------------
# Pattern rules (data-driven)
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class _MateRequirement:
"""A single mate requirement within a pattern rule."""
mate_type: MateType
geometry_a: GeometryType | None = None
geometry_b: GeometryType | None = None
@dataclass(frozen=True)
class _PatternRule:
"""Defines a canonical pattern as a set of required mates."""
pattern: JointPattern
joint_type: JointType
required: tuple[_MateRequirement, ...]
description: str = ""
_PATTERN_RULES: list[_PatternRule] = [
_PatternRule(
pattern=JointPattern.HINGE,
joint_type=JointType.REVOLUTE,
required=(
_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
),
description="Concentric axes + coincident plane",
),
_PatternRule(
pattern=JointPattern.SLIDER,
joint_type=JointType.SLIDER,
required=(
_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
_MateRequirement(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
),
description="Coincident plane + parallel axis",
),
_PatternRule(
pattern=JointPattern.CYLINDER,
joint_type=JointType.CYLINDRICAL,
required=(_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),),
description="Concentric axes only",
),
_PatternRule(
pattern=JointPattern.BALL,
joint_type=JointType.BALL,
required=(_MateRequirement(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),),
description="Coincident points",
),
_PatternRule(
pattern=JointPattern.PLANAR,
joint_type=JointType.PLANAR,
required=(_MateRequirement(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),),
description="Coincident faces",
),
_PatternRule(
pattern=JointPattern.PLANAR,
joint_type=JointType.PLANAR,
required=(_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),),
description="Coincident planes (alternate planar)",
),
_PatternRule(
pattern=JointPattern.FIXED,
joint_type=JointType.FIXED,
required=(_MateRequirement(MateType.LOCK),),
description="Lock mate",
),
]
# ---------------------------------------------------------------------------
# Matching logic
# ---------------------------------------------------------------------------
def _mate_matches_requirement(mate: Mate, req: _MateRequirement) -> bool:
"""Check if a mate satisfies a requirement."""
if mate.mate_type is not req.mate_type:
return False
if req.geometry_a is not None and mate.ref_a.geometry_type is not req.geometry_a:
return False
return not (req.geometry_b is not None and mate.ref_b.geometry_type is not req.geometry_b)
def _try_match_rule(
rule: _PatternRule,
mates: list[Mate],
) -> tuple[float, list[Mate], list[str]]:
"""Try to match a rule against a group of mates.
Returns:
(confidence, matched_mates, missing_descriptions)
"""
matched: list[Mate] = []
missing: list[str] = []
for req in rule.required:
found = False
for mate in mates:
if mate in matched:
continue
if _mate_matches_requirement(mate, req):
matched.append(mate)
found = True
break
if not found:
geom_desc = ""
if req.geometry_a is not None:
geom_b = req.geometry_b.value if req.geometry_b else "*"
geom_desc = f" ({req.geometry_a.value}-{geom_b})"
missing.append(f"{req.mate_type.name}{geom_desc}")
total_required = len(rule.required)
if total_required == 0:
return 0.0, [], []
matched_count = len(matched)
confidence = matched_count / total_required
return confidence, matched, missing
def _normalize_body_pair(body_a: int, body_b: int) -> tuple[int, int]:
"""Normalize a body pair so the smaller ID comes first."""
return (min(body_a, body_b), max(body_a, body_b))
def recognize_patterns(mates: list[Mate]) -> list[PatternMatch]:
"""Identify joint patterns from a list of mates.
Groups mates by body pair, then checks each group against
canonical pattern rules. Returns matches sorted by confidence
descending.
Args:
mates: List of mate constraints to analyze.
Returns:
List of PatternMatch results, highest confidence first.
"""
if not mates:
return []
# Group mates by normalized body pair
groups: dict[tuple[int, int], list[Mate]] = defaultdict(list)
for mate in mates:
pair = _normalize_body_pair(mate.ref_a.body_id, mate.ref_b.body_id)
groups[pair].append(mate)
results: list[PatternMatch] = []
for (body_a, body_b), group_mates in groups.items():
group_matches: list[PatternMatch] = []
for rule in _PATTERN_RULES:
confidence, matched, missing = _try_match_rule(rule, group_mates)
if confidence > 0:
group_matches.append(
PatternMatch(
pattern=rule.pattern,
mates=matched if matched else group_mates,
body_a=body_a,
body_b=body_b,
confidence=confidence,
equivalent_joint_type=rule.joint_type,
missing_mates=missing,
)
)
if group_matches:
# Sort by confidence descending, prefer more-specific patterns
group_matches.sort(key=lambda m: (-m.confidence, -len(m.mates)))
results.extend(group_matches)
else:
# No pattern matched at all
results.append(
PatternMatch(
pattern=JointPattern.UNKNOWN,
mates=group_mates,
body_a=body_a,
body_b=body_b,
confidence=0.0,
equivalent_joint_type=JointType.DISTANCE,
missing_mates=[],
)
)
# Global sort by confidence descending
results.sort(key=lambda m: -m.confidence)
return results

279
solver/mates/primitives.py Normal file
View File

@@ -0,0 +1,279 @@
"""Mate type definitions and geometry references for assembly constraints.
Mates are the user-facing constraint primitives in CAD (e.g. SolidWorks-style
Coincident, Concentric, Parallel). Each mate references geometry on two bodies
and removes a context-dependent number of degrees of freedom.
"""
from __future__ import annotations
import enum
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import numpy as np
if TYPE_CHECKING:
from typing import Any
__all__ = [
"GeometryRef",
"GeometryType",
"Mate",
"MateType",
"dof_removed",
]
# ---------------------------------------------------------------------------
# Enums
# ---------------------------------------------------------------------------
class MateType(enum.Enum):
"""CAD mate types with default DOF-removal counts.
Values are ``(ordinal, default_dof)`` tuples so that mate types
sharing the same DOF count remain distinct enum members. Use the
:attr:`default_dof` property to get the scalar constraint count.
The actual DOF removed can be context-dependent (e.g. COINCIDENT
removes 3 DOF for face-face but only 1 for face-point). Use
:func:`dof_removed` for the context-aware count.
"""
COINCIDENT = (0, 3)
CONCENTRIC = (1, 2)
PARALLEL = (2, 2)
PERPENDICULAR = (3, 1)
TANGENT = (4, 1)
DISTANCE = (5, 1)
ANGLE = (6, 1)
LOCK = (7, 6)
@property
def default_dof(self) -> int:
"""Default number of DOF removed by this mate type."""
return self.value[1]
class GeometryType(enum.Enum):
"""Types of geometric references used by mates."""
FACE = "face"
EDGE = "edge"
POINT = "point"
AXIS = "axis"
PLANE = "plane"
# Geometry types that require a direction vector.
_DIRECTIONAL_TYPES = frozenset(
{
GeometryType.FACE,
GeometryType.AXIS,
GeometryType.PLANE,
}
)
# ---------------------------------------------------------------------------
# Dataclasses
# ---------------------------------------------------------------------------
@dataclass
class GeometryRef:
"""A reference to a specific geometric entity on a body.
Attributes:
body_id: Index of the body this geometry belongs to.
geometry_type: What kind of geometry (face, edge, etc.).
geometry_id: CAD identifier string (e.g. ``"Face001"``).
origin: 3D position of the geometry reference point.
direction: Unit direction vector. Required for FACE, AXIS, PLANE;
``None`` for POINT.
"""
body_id: int
geometry_type: GeometryType
geometry_id: str
origin: np.ndarray = field(default_factory=lambda: np.zeros(3))
direction: np.ndarray | None = None
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"body_id": self.body_id,
"geometry_type": self.geometry_type.value,
"geometry_id": self.geometry_id,
"origin": self.origin.tolist(),
"direction": self.direction.tolist() if self.direction is not None else None,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> GeometryRef:
"""Construct from a dict produced by :meth:`to_dict`."""
direction_raw = data.get("direction")
return cls(
body_id=data["body_id"],
geometry_type=GeometryType(data["geometry_type"]),
geometry_id=data["geometry_id"],
origin=np.asarray(data["origin"], dtype=np.float64),
direction=(
np.asarray(direction_raw, dtype=np.float64) if direction_raw is not None else None
),
)
@dataclass
class Mate:
"""A mate constraint between geometry on two bodies.
Attributes:
mate_id: Unique identifier for this mate.
mate_type: The type of constraint (Coincident, Concentric, etc.).
ref_a: Geometry reference on the first body.
ref_b: Geometry reference on the second body.
value: Scalar parameter for DISTANCE and ANGLE mates (0 otherwise).
tolerance: Numeric tolerance for constraint satisfaction.
"""
mate_id: int
mate_type: MateType
ref_a: GeometryRef
ref_b: GeometryRef
value: float = 0.0
tolerance: float = 1e-6
def validate(self) -> None:
"""Raise ``ValueError`` if this mate has incompatible geometry.
Checks:
- Self-mate (both refs on same body)
- CONCENTRIC requires AXIS geometry on both refs
- PARALLEL requires directional geometry (not POINT)
- TANGENT requires surface geometry (FACE or EDGE)
- Directional geometry types must have a direction vector
"""
if self.ref_a.body_id == self.ref_b.body_id:
msg = f"Self-mate: ref_a and ref_b both reference body {self.ref_a.body_id}"
raise ValueError(msg)
for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
if ref.geometry_type in _DIRECTIONAL_TYPES and ref.direction is None:
msg = (
f"{label}: geometry type {ref.geometry_type.value} requires a direction vector"
)
raise ValueError(msg)
if self.mate_type is MateType.CONCENTRIC:
for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
if ref.geometry_type is not GeometryType.AXIS:
msg = (
f"CONCENTRIC mate requires AXIS geometry, "
f"got {ref.geometry_type.value} on {label}"
)
raise ValueError(msg)
if self.mate_type is MateType.PARALLEL:
for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
if ref.geometry_type is GeometryType.POINT:
msg = f"PARALLEL mate requires directional geometry, got POINT on {label}"
raise ValueError(msg)
if self.mate_type is MateType.TANGENT:
_surface = frozenset({GeometryType.FACE, GeometryType.EDGE})
for label, ref in [("ref_a", self.ref_a), ("ref_b", self.ref_b)]:
if ref.geometry_type not in _surface:
msg = (
f"TANGENT mate requires surface geometry "
f"(FACE or EDGE), got {ref.geometry_type.value} "
f"on {label}"
)
raise ValueError(msg)
def to_dict(self) -> dict[str, Any]:
"""Return a JSON-serializable dict."""
return {
"mate_id": self.mate_id,
"mate_type": self.mate_type.name,
"ref_a": self.ref_a.to_dict(),
"ref_b": self.ref_b.to_dict(),
"value": self.value,
"tolerance": self.tolerance,
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> Mate:
"""Construct from a dict produced by :meth:`to_dict`."""
return cls(
mate_id=data["mate_id"],
mate_type=MateType[data["mate_type"]],
ref_a=GeometryRef.from_dict(data["ref_a"]),
ref_b=GeometryRef.from_dict(data["ref_b"]),
value=data.get("value", 0.0),
tolerance=data.get("tolerance", 1e-6),
)
# ---------------------------------------------------------------------------
# Context-dependent DOF removal
# ---------------------------------------------------------------------------
# Lookup table: (MateType, ref_a GeometryType, ref_b GeometryType) -> DOF removed.
# Entries with None match any geometry type for that position.
_DOF_TABLE: dict[tuple[MateType, GeometryType | None, GeometryType | None], int] = {
# COINCIDENT — context-dependent
(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE): 3,
(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT): 3,
(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE): 3,
(MateType.COINCIDENT, GeometryType.EDGE, GeometryType.EDGE): 2,
(MateType.COINCIDENT, GeometryType.FACE, GeometryType.POINT): 1,
(MateType.COINCIDENT, GeometryType.POINT, GeometryType.FACE): 1,
# CONCENTRIC
(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS): 2,
# PARALLEL
(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS): 2,
(MateType.PARALLEL, GeometryType.FACE, GeometryType.FACE): 2,
(MateType.PARALLEL, GeometryType.PLANE, GeometryType.PLANE): 2,
# TANGENT
(MateType.TANGENT, GeometryType.FACE, GeometryType.FACE): 1,
(MateType.TANGENT, GeometryType.FACE, GeometryType.EDGE): 1,
(MateType.TANGENT, GeometryType.EDGE, GeometryType.FACE): 1,
# Types where DOF is always the same regardless of geometry
(MateType.PERPENDICULAR, None, None): 1,
(MateType.DISTANCE, None, None): 1,
(MateType.ANGLE, None, None): 1,
(MateType.LOCK, None, None): 6,
}
def dof_removed(
mate_type: MateType,
ref_a: GeometryRef,
ref_b: GeometryRef,
) -> int:
"""Return the number of DOF removed by a mate given its geometry context.
Looks up the exact ``(mate_type, ref_a.geometry_type, ref_b.geometry_type)``
combination first, then falls back to a wildcard ``(mate_type, None, None)``
entry, and finally to :attr:`MateType.default_dof`.
Args:
mate_type: The mate constraint type.
ref_a: Geometry reference on the first body.
ref_b: Geometry reference on the second body.
Returns:
Number of scalar DOF removed by this mate.
"""
key = (mate_type, ref_a.geometry_type, ref_b.geometry_type)
if key in _DOF_TABLE:
return _DOF_TABLE[key]
wildcard = (mate_type, None, None)
if wildcard in _DOF_TABLE:
return _DOF_TABLE[wildcard]
return mate_type.default_dof

View File

View File

0
tests/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,240 @@
"""Tests for solver.datagen.analysis -- combined analysis function."""
from __future__ import annotations
import numpy as np
import pytest
from solver.datagen.analysis import analyze_assembly
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
RigidBody,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _two_bodies() -> list[RigidBody]:
return [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
]
def _triangle_bodies() -> list[RigidBody]:
return [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
RigidBody(2, position=np.array([1.0, 1.7, 0.0])),
]
# ---------------------------------------------------------------------------
# Scenario 1: Two bodies + revolute (underconstrained, 1 internal DOF)
# ---------------------------------------------------------------------------
class TestTwoBodiesRevolute:
"""Demo scenario 1: two bodies connected by a revolute joint."""
@pytest.fixture()
def result(self) -> ConstraintAnalysis:
bodies = _two_bodies()
joints = [
Joint(
0,
body_a=0,
body_b=1,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
]
return analyze_assembly(bodies, joints, ground_body=0)
def test_internal_dof(self, result: ConstraintAnalysis) -> None:
assert result.jacobian_internal_dof == 1
def test_not_rigid(self, result: ConstraintAnalysis) -> None:
assert not result.is_rigid
def test_classification(self, result: ConstraintAnalysis) -> None:
assert result.combinatorial_classification == "underconstrained"
def test_no_redundant(self, result: ConstraintAnalysis) -> None:
assert result.combinatorial_redundant == 0
# ---------------------------------------------------------------------------
# Scenario 2: Two bodies + fixed (well-constrained, 0 internal DOF)
# ---------------------------------------------------------------------------
class TestTwoBodiesFixed:
"""Demo scenario 2: two bodies connected by a fixed joint."""
@pytest.fixture()
def result(self) -> ConstraintAnalysis:
bodies = _two_bodies()
joints = [
Joint(
0,
body_a=0,
body_b=1,
joint_type=JointType.FIXED,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
),
]
return analyze_assembly(bodies, joints, ground_body=0)
def test_internal_dof(self, result: ConstraintAnalysis) -> None:
assert result.jacobian_internal_dof == 0
def test_rigid(self, result: ConstraintAnalysis) -> None:
assert result.is_rigid
def test_minimally_rigid(self, result: ConstraintAnalysis) -> None:
assert result.is_minimally_rigid
def test_classification(self, result: ConstraintAnalysis) -> None:
assert result.combinatorial_classification == "well-constrained"
# ---------------------------------------------------------------------------
# Scenario 3: Triangle with revolute joints (overconstrained)
# ---------------------------------------------------------------------------
class TestTriangleRevolute:
"""Demo scenario 3: triangle of 3 bodies + 3 revolute joints."""
@pytest.fixture()
def result(self) -> ConstraintAnalysis:
bodies = _triangle_bodies()
joints = [
Joint(
0,
body_a=0,
body_b=1,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
Joint(
1,
body_a=1,
body_b=2,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.5, 0.85, 0.0]),
anchor_b=np.array([1.5, 0.85, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
Joint(
2,
body_a=2,
body_b=0,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([0.5, 0.85, 0.0]),
anchor_b=np.array([0.5, 0.85, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
]
return analyze_assembly(bodies, joints, ground_body=0)
def test_has_redundant(self, result: ConstraintAnalysis) -> None:
assert result.combinatorial_redundant > 0
def test_classification(self, result: ConstraintAnalysis) -> None:
assert result.combinatorial_classification in ("overconstrained", "mixed")
def test_rigid(self, result: ConstraintAnalysis) -> None:
assert result.is_rigid
def test_numerically_dependent(self, result: ConstraintAnalysis) -> None:
assert len(result.numerically_dependent) > 0
# ---------------------------------------------------------------------------
# Scenario 4: Parallel revolute axes (geometric degeneracy)
# ---------------------------------------------------------------------------
class TestParallelRevoluteAxes:
"""Demo scenario 4: parallel revolute axes create geometric degeneracies."""
@pytest.fixture()
def result(self) -> ConstraintAnalysis:
bodies = [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
]
joints = [
Joint(
0,
body_a=0,
body_b=1,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
Joint(
1,
body_a=1,
body_b=2,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([3.0, 0.0, 0.0]),
anchor_b=np.array([3.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
),
]
return analyze_assembly(bodies, joints, ground_body=0)
def test_geometric_degeneracies_detected(self, result: ConstraintAnalysis) -> None:
"""Parallel axes produce at least one geometric degeneracy."""
assert result.geometric_degeneracies > 0
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestNoJoints:
"""Assembly with bodies but no joints."""
def test_all_dof_free(self) -> None:
bodies = _two_bodies()
result = analyze_assembly(bodies, [], ground_body=0)
# Body 1 is completely free (6 DOF), body 0 is grounded
assert result.jacobian_internal_dof > 0
assert not result.is_rigid
def test_ungrounded(self) -> None:
bodies = _two_bodies()
result = analyze_assembly(bodies, [])
assert result.combinatorial_classification == "underconstrained"
class TestReturnType:
"""Verify the return object is a proper ConstraintAnalysis."""
def test_instance(self) -> None:
bodies = _two_bodies()
joints = [Joint(0, 0, 1, JointType.FIXED)]
result = analyze_assembly(bodies, joints)
assert isinstance(result, ConstraintAnalysis)
def test_per_edge_results_populated(self) -> None:
bodies = _two_bodies()
joints = [Joint(0, 0, 1, JointType.REVOLUTE)]
result = analyze_assembly(bodies, joints)
assert len(result.per_edge_results) == 5

View File

@@ -0,0 +1,337 @@
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
from __future__ import annotations
import json
import os
import subprocess
import sys
from typing import TYPE_CHECKING
from solver.datagen.dataset import (
DatasetConfig,
DatasetGenerator,
_derive_shard_seed,
_parse_scalar,
parse_simple_yaml,
)
if TYPE_CHECKING:
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# DatasetConfig
# ---------------------------------------------------------------------------
class TestDatasetConfig:
"""DatasetConfig construction and defaults."""
def test_defaults(self) -> None:
cfg = DatasetConfig()
assert cfg.num_assemblies == 100_000
assert cfg.seed == 42
assert cfg.shard_size == 1000
assert cfg.num_workers == 4
def test_from_dict_flat(self) -> None:
d: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
cfg = DatasetConfig.from_dict(d)
assert cfg.num_assemblies == 500
assert cfg.seed == 123
def test_from_dict_nested_body_count(self) -> None:
d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
cfg = DatasetConfig.from_dict(d)
assert cfg.body_count_min == 3
assert cfg.body_count_max == 20
def test_from_dict_flat_body_count(self) -> None:
d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
cfg = DatasetConfig.from_dict(d)
assert cfg.body_count_min == 5
assert cfg.body_count_max == 30
def test_from_dict_complexity_distribution(self) -> None:
d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}}
cfg = DatasetConfig.from_dict(d)
assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4}
def test_from_dict_templates(self) -> None:
d: dict[str, Any] = {"templates": ["chain", "tree"]}
cfg = DatasetConfig.from_dict(d)
assert cfg.templates == ["chain", "tree"]
# ---------------------------------------------------------------------------
# Minimal YAML parser
# ---------------------------------------------------------------------------
class TestParseScalar:
"""_parse_scalar handles different value types."""
def test_int(self) -> None:
assert _parse_scalar("42") == 42
def test_float(self) -> None:
assert _parse_scalar("3.14") == 3.14
def test_bool_true(self) -> None:
assert _parse_scalar("true") is True
def test_bool_false(self) -> None:
assert _parse_scalar("false") is False
def test_string(self) -> None:
assert _parse_scalar("hello") == "hello"
def test_inline_comment(self) -> None:
assert _parse_scalar("0.4 # some comment") == 0.4
class TestParseSimpleYaml:
"""parse_simple_yaml handles the synthetic.yaml format."""
def test_flat_scalars(self, tmp_path: Path) -> None:
yaml_file = tmp_path / "test.yaml"
yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
result = parse_simple_yaml(str(yaml_file))
assert result["name"] == "test"
assert result["num"] == 42
assert result["ratio"] == 0.5
def test_nested_dict(self, tmp_path: Path) -> None:
yaml_file = tmp_path / "test.yaml"
yaml_file.write_text("body_count:\n min: 2\n max: 50\n")
result = parse_simple_yaml(str(yaml_file))
assert result["body_count"] == {"min": 2, "max": 50}
def test_list(self, tmp_path: Path) -> None:
yaml_file = tmp_path / "test.yaml"
yaml_file.write_text("templates:\n - chain\n - tree\n - loop\n")
result = parse_simple_yaml(str(yaml_file))
assert result["templates"] == ["chain", "tree", "loop"]
def test_inline_comments(self, tmp_path: Path) -> None:
yaml_file = tmp_path / "test.yaml"
yaml_file.write_text("dist:\n simple: 0.4 # comment\n")
result = parse_simple_yaml(str(yaml_file))
assert result["dist"]["simple"] == 0.4
def test_synthetic_yaml(self) -> None:
"""Parse the actual project config."""
result = parse_simple_yaml("configs/dataset/synthetic.yaml")
assert result["name"] == "synthetic"
assert result["num_assemblies"] == 100000
assert isinstance(result["complexity_distribution"], dict)
assert isinstance(result["templates"], list)
assert result["shard_size"] == 1000
# ---------------------------------------------------------------------------
# Shard seed derivation
# ---------------------------------------------------------------------------
class TestShardSeedDerivation:
"""_derive_shard_seed is deterministic and unique per shard."""
def test_deterministic(self) -> None:
s1 = _derive_shard_seed(42, 0)
s2 = _derive_shard_seed(42, 0)
assert s1 == s2
def test_different_shards(self) -> None:
s1 = _derive_shard_seed(42, 0)
s2 = _derive_shard_seed(42, 1)
assert s1 != s2
def test_different_global_seeds(self) -> None:
s1 = _derive_shard_seed(42, 0)
s2 = _derive_shard_seed(99, 0)
assert s1 != s2
# ---------------------------------------------------------------------------
# DatasetGenerator — small end-to-end tests
# ---------------------------------------------------------------------------
class TestDatasetGenerator:
"""End-to-end tests with small datasets."""
def test_small_generation(self, tmp_path: Path) -> None:
"""Generate 10 examples in a single shard."""
cfg = DatasetConfig(
num_assemblies=10,
output_dir=str(tmp_path / "output"),
shard_size=10,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
shards_dir = tmp_path / "output" / "shards"
assert shards_dir.exists()
shard_files = sorted(shards_dir.glob("shard_*"))
assert len(shard_files) == 1
index_file = tmp_path / "output" / "index.json"
assert index_file.exists()
index = json.loads(index_file.read_text())
assert index["total_assemblies"] == 10
stats_file = tmp_path / "output" / "stats.json"
assert stats_file.exists()
def test_multi_shard(self, tmp_path: Path) -> None:
"""Generate 20 examples across 2 shards."""
cfg = DatasetConfig(
num_assemblies=20,
output_dir=str(tmp_path / "output"),
shard_size=10,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
shards_dir = tmp_path / "output" / "shards"
shard_files = sorted(shards_dir.glob("shard_*"))
assert len(shard_files) == 2
def test_resume_skips_completed(self, tmp_path: Path) -> None:
"""Resume skips already-completed shards."""
cfg = DatasetConfig(
num_assemblies=20,
output_dir=str(tmp_path / "output"),
shard_size=10,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
# Record shard modification times
shards_dir = tmp_path / "output" / "shards"
mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}
# Remove stats (simulate incomplete) and re-run
(tmp_path / "output" / "stats.json").unlink()
DatasetGenerator(cfg).run()
# Shards should NOT have been regenerated
for p in shards_dir.glob("shard_*"):
assert p.stat().st_mtime == mtimes[p.name]
# Stats should be regenerated
assert (tmp_path / "output" / "stats.json").exists()
def test_checkpoint_removed(self, tmp_path: Path) -> None:
"""Checkpoint file is cleaned up after completion."""
cfg = DatasetConfig(
num_assemblies=5,
output_dir=str(tmp_path / "output"),
shard_size=5,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
checkpoint = tmp_path / "output" / ".checkpoint.json"
assert not checkpoint.exists()
def test_stats_structure(self, tmp_path: Path) -> None:
"""stats.json has expected top-level keys."""
cfg = DatasetConfig(
num_assemblies=10,
output_dir=str(tmp_path / "output"),
shard_size=10,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
stats = json.loads((tmp_path / "output" / "stats.json").read_text())
assert stats["total_examples"] == 10
assert "classification_distribution" in stats
assert "body_count_histogram" in stats
assert "joint_type_distribution" in stats
assert "dof_statistics" in stats
assert "geometric_degeneracy" in stats
assert "rigidity" in stats
def test_index_structure(self, tmp_path: Path) -> None:
"""index.json has expected format."""
cfg = DatasetConfig(
num_assemblies=15,
output_dir=str(tmp_path / "output"),
shard_size=10,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
index = json.loads((tmp_path / "output" / "index.json").read_text())
assert index["format_version"] == 1
assert index["total_assemblies"] == 15
assert index["total_shards"] == 2
assert "shards" in index
for _name, info in index["shards"].items():
assert "start_id" in info
assert "count" in info
def test_deterministic_output(self, tmp_path: Path) -> None:
"""Same seed produces same results."""
for run_dir in ("run1", "run2"):
cfg = DatasetConfig(
num_assemblies=5,
output_dir=str(tmp_path / run_dir),
shard_size=5,
seed=42,
num_workers=1,
)
DatasetGenerator(cfg).run()
s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
assert s1["total_examples"] == s2["total_examples"]
assert s1["classification_distribution"] == s2["classification_distribution"]
# ---------------------------------------------------------------------------
# CLI integration test
# ---------------------------------------------------------------------------
class TestCLI:
"""Run the script via subprocess."""
def test_argparse_mode(self, tmp_path: Path) -> None:
result = subprocess.run(
[
sys.executable,
"scripts/generate_synthetic.py",
"--num-assemblies",
"5",
"--output-dir",
str(tmp_path / "cli_out"),
"--shard-size",
"5",
"--num-workers",
"1",
"--seed",
"42",
],
capture_output=True,
text=True,
cwd="/home/developer",
timeout=120,
env={**os.environ, "PYTHONPATH": "/home/developer"},
)
assert result.returncode == 0, (
f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
)
assert (tmp_path / "cli_out" / "index.json").exists()
assert (tmp_path / "cli_out" / "stats.json").exists()

View File

@@ -0,0 +1,682 @@
"""Tests for solver.datagen.generator -- synthetic assembly generation."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.generator import COMPLEXITY_RANGES, SyntheticAssemblyGenerator
from solver.datagen.types import JointType
# ---------------------------------------------------------------------------
# Original generators (chain / rigid / overconstrained)
# ---------------------------------------------------------------------------
class TestChainAssembly:
"""generate_chain_assembly produces valid underconstrained chains."""
def test_returns_three_tuple(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
bodies, joints, _analysis = gen.generate_chain_assembly(4)
assert len(bodies) == 4
assert len(joints) == 3
def test_chain_underconstrained(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
_, _, analysis = gen.generate_chain_assembly(4)
assert analysis.combinatorial_classification == "underconstrained"
def test_chain_body_ids(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
bodies, _, _ = gen.generate_chain_assembly(5)
ids = [b.body_id for b in bodies]
assert ids == [0, 1, 2, 3, 4]
def test_chain_joint_connectivity(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
_, joints, _ = gen.generate_chain_assembly(4)
for i, j in enumerate(joints):
assert j.body_a == i
assert j.body_b == i + 1
def test_chain_custom_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
_, joints, _ = gen.generate_chain_assembly(
3,
joint_type=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
class TestRigidAssembly:
"""generate_rigid_assembly produces rigid assemblies."""
def test_rigid(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_rigid_assembly(4)
assert analysis.is_rigid
def test_spanning_tree_structure(self) -> None:
"""n bodies should have at least n-1 joints (spanning tree)."""
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_rigid_assembly(5)
assert len(joints) >= len(bodies) - 1
def test_deterministic(self) -> None:
"""Same seed produces same results."""
g1 = SyntheticAssemblyGenerator(seed=99)
g2 = SyntheticAssemblyGenerator(seed=99)
_, j1, a1 = g1.generate_rigid_assembly(4)
_, j2, a2 = g2.generate_rigid_assembly(4)
assert a1.jacobian_rank == a2.jacobian_rank
assert len(j1) == len(j2)
class TestOverconstrainedAssembly:
"""generate_overconstrained_assembly adds redundant constraints."""
def test_has_redundant(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_overconstrained_assembly(
4,
extra_joints=2,
)
assert analysis.combinatorial_redundant > 0
def test_extra_joints_added(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints_base, _ = gen.generate_rigid_assembly(4)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, joints_over, _ = gen2.generate_overconstrained_assembly(
4,
extra_joints=3,
)
# Overconstrained has base joints + extra
assert len(joints_over) > len(joints_base)
# ---------------------------------------------------------------------------
# New topology generators
# ---------------------------------------------------------------------------
class TestTreeAssembly:
"""generate_tree_assembly produces tree-structured assemblies."""
def test_body_and_joint_counts(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_tree_assembly(6)
assert len(bodies) == 6
assert len(joints) == 5 # n - 1
def test_underconstrained(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_tree_assembly(6)
assert analysis.combinatorial_classification == "underconstrained"
def test_branching_factor(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_tree_assembly(
10,
branching_factor=2,
)
assert len(bodies) == 10
assert len(joints) == 9
def test_mixed_joint_types(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
types = [JointType.REVOLUTE, JointType.BALL, JointType.FIXED]
_, joints, _ = gen.generate_tree_assembly(10, joint_types=types)
used = {j.joint_type for j in joints}
# With 9 joints and 3 types, very likely to use at least 2
assert len(used) >= 2
def test_single_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_tree_assembly(
5,
joint_types=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
def test_sequential_body_ids(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_tree_assembly(7)
assert [b.body_id for b in bodies] == list(range(7))
class TestLoopAssembly:
"""generate_loop_assembly produces closed-loop assemblies."""
def test_body_and_joint_counts(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_loop_assembly(5)
assert len(bodies) == 5
assert len(joints) == 5 # n joints for n bodies
def test_has_redundancy(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_loop_assembly(5)
assert analysis.combinatorial_redundant > 0
def test_wrap_around_connectivity(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_loop_assembly(4)
edges = {(j.body_a, j.body_b) for j in joints}
assert (0, 1) in edges
assert (1, 2) in edges
assert (2, 3) in edges
assert (3, 0) in edges # wrap-around
def test_minimum_bodies_error(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
with pytest.raises(ValueError, match="at least 3"):
gen.generate_loop_assembly(2)
def test_mixed_joint_types(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
types = [JointType.REVOLUTE, JointType.FIXED]
_, joints, _ = gen.generate_loop_assembly(8, joint_types=types)
used = {j.joint_type for j in joints}
assert len(used) >= 2
class TestStarAssembly:
"""generate_star_assembly produces hub-and-spoke assemblies."""
def test_body_and_joint_counts(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_star_assembly(6)
assert len(bodies) == 6
assert len(joints) == 5 # n - 1
def test_all_joints_connect_to_hub(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_star_assembly(6)
for j in joints:
assert j.body_a == 0 or j.body_b == 0
def test_underconstrained(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_star_assembly(5)
assert analysis.combinatorial_classification == "underconstrained"
def test_minimum_bodies_error(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
with pytest.raises(ValueError, match="at least 2"):
gen.generate_star_assembly(1)
def test_hub_at_origin(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_star_assembly(4)
np.testing.assert_array_equal(bodies[0].position, np.zeros(3))
class TestMixedAssembly:
"""generate_mixed_assembly produces tree+loop hybrid assemblies."""
def test_more_joints_than_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_mixed_assembly(
8,
edge_density=0.3,
)
assert len(joints) > len(bodies) - 1
def test_density_zero_is_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=0.0,
)
assert len(joints) == 4 # spanning tree only
def test_density_validation(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
with pytest.raises(ValueError, match="must be in"):
gen.generate_mixed_assembly(5, edge_density=1.5)
with pytest.raises(ValueError, match="must be in"):
gen.generate_mixed_assembly(5, edge_density=-0.1)
def test_no_duplicate_edges(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_mixed_assembly(6, edge_density=0.5)
edges = [frozenset([j.body_a, j.body_b]) for j in joints]
assert len(edges) == len(set(edges))
def test_high_density(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=1.0,
)
# Fully connected: 5*(5-1)/2 = 10 edges
assert len(joints) == 10
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
class TestAxisStrategy:
"""Axis sampling strategies produce valid unit vectors."""
def test_cardinal_axis_from_six(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
expected = {
(1, 0, 0),
(-1, 0, 0),
(0, 1, 0),
(0, -1, 0),
(0, 0, 1),
(0, 0, -1),
}
assert axes == expected
def test_random_axis_unit_norm(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
for _ in range(50):
axis = gen._sample_axis("random")
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
def test_near_parallel_close_to_base(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
base = np.array([0.0, 0.0, 1.0])
for _ in range(50):
axis = gen._near_parallel_axis(base)
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
assert np.dot(axis, base) > 0.95
def test_sample_axis_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("cardinal")
cardinals = [
np.array(v, dtype=float)
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
]
assert any(np.allclose(axis, c) for c in cardinals)
def test_sample_axis_near_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("near_parallel")
z = np.array([0.0, 0.0, 1.0])
assert np.dot(axis, z) > 0.95
# ---------------------------------------------------------------------------
# Geometric diversity: orientations
# ---------------------------------------------------------------------------
class TestRandomOrientations:
"""Bodies should have non-identity orientations."""
def test_bodies_have_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_tree_assembly(5)
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
assert non_identity >= 3
def test_orientations_are_valid_rotations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_star_assembly(6)
for b in bodies:
r = b.orientation
# R^T R == I
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
# det(R) == 1
assert abs(np.linalg.det(r) - 1.0) < 1e-10
def test_all_generators_set_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
# Chain
bodies, _, _ = gen.generate_chain_assembly(3)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Loop
bodies, _, _ = gen.generate_loop_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Mixed
bodies, _, _ = gen.generate_mixed_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# ---------------------------------------------------------------------------
# Geometric diversity: grounded parameter
# ---------------------------------------------------------------------------
class TestGroundedParameter:
"""Grounded parameter controls ground_body in analysis."""
def test_chain_grounded_default(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(4)
assert analysis.combinatorial_dof >= 0
def test_chain_floating(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(
4,
grounded=False,
)
# Floating: 6 trivial DOF not subtracted by ground
assert analysis.combinatorial_dof >= 6
def test_floating_vs_grounded_dof_difference(self) -> None:
gen1 = SyntheticAssemblyGenerator(seed=42)
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
# Floating should have higher DOF due to missing ground constraint
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
@pytest.mark.parametrize(
"gen_method",
[
"generate_chain_assembly",
"generate_rigid_assembly",
"generate_tree_assembly",
"generate_loop_assembly",
"generate_star_assembly",
"generate_mixed_assembly",
],
)
def test_all_generators_accept_grounded(
self,
gen_method: str,
) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
method = getattr(gen, gen_method)
n = 4
# Should not raise
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
method(n, grounded=False)
else:
method(n, grounded=False)
# ---------------------------------------------------------------------------
# Geometric diversity: parallel axis injection
# ---------------------------------------------------------------------------
class TestParallelAxisInjection:
"""parallel_axis_prob causes shared axis direction."""
def test_parallel_axes_similar(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
# Near-parallel: |dot| close to 1
assert abs(np.dot(j.axis, base)) > 0.9
def test_zero_prob_no_forced_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=0.0,
)
base = joints[0].axis
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
# With 5 random axes, extremely unlikely all are parallel
assert min(dots) < 0.95
def test_parallel_on_loop(self) -> None:
"""Parallel axes on a loop assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_loop_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
def test_parallel_on_star(self) -> None:
"""Parallel axes on a star assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_star_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
class TestComplexityTiers:
"""Complexity tier parameter on batch generation."""
def test_simple_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, complexity_tier="simple")
lo, hi = COMPLEXITY_RANGES["simple"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
def test_medium_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, complexity_tier="medium")
lo, hi = COMPLEXITY_RANGES["medium"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
def test_complex_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(3, complexity_tier="complex")
lo, hi = COMPLEXITY_RANGES["complex"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
def test_tier_overrides_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
n_bodies_range=(2, 3),
complexity_tier="medium",
)
lo, hi = COMPLEXITY_RANGES["medium"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
# ---------------------------------------------------------------------------
# Training batch
# ---------------------------------------------------------------------------
class TestTrainingBatch:
"""generate_training_batch produces well-structured examples."""
EXPECTED_KEYS: ClassVar[set[str]] = {
"example_id",
"generator_type",
"grounded",
"n_bodies",
"n_joints",
"body_positions",
"body_orientations",
"joints",
"joint_labels",
"labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
"internal_dof",
"geometric_degeneracies",
}
VALID_GEN_TYPES: ClassVar[set[str]] = {
"chain",
"rigid",
"overconstrained",
"tree",
"loop",
"star",
"mixed",
}
def test_batch_size(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20)
assert len(batch) == 20
def test_example_keys(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert set(ex.keys()) == self.EXPECTED_KEYS
def test_example_ids_sequential(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(15)
assert [ex["example_id"] for ex in batch] == list(range(15))
def test_generator_type_valid(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(50)
for ex in batch:
assert ex["generator_type"] in self.VALID_GEN_TYPES
def test_generator_type_diversity(self) -> None:
"""100-sample batch should use at least 5 of 7 generator types."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(100)
types = {ex["generator_type"] for ex in batch}
assert len(types) >= 5
def test_default_body_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
# default (3, 8), but loop/star may clamp
assert 2 <= ex["n_bodies"] <= 7
def test_joint_label_consistency(self) -> None:
"""independent + redundant == total for every joint."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
for label in ex["joint_labels"].values():
total = label["independent_constraints"] + label["redundant_constraints"]
assert total == label["total_constraints"]
def test_body_orientations_present(self) -> None:
"""Each example includes body_orientations as 3x3 lists."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
orients = ex["body_orientations"]
assert len(orients) == ex["n_bodies"]
for o in orients:
assert len(o) == 3
assert len(o[0]) == 3
def test_labels_structure(self) -> None:
"""Each example has labels dict with expected sub-keys."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
labels = ex["labels"]
assert "per_constraint" in labels
assert "per_joint" in labels
assert "per_body" in labels
assert "assembly" in labels
def test_grounded_field_present(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert isinstance(ex["grounded"], bool)
# ---------------------------------------------------------------------------
# Batch grounded ratio
# ---------------------------------------------------------------------------
class TestBatchGroundedRatio:
"""grounded_ratio controls the mix in batch generation."""
def test_all_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
assert all(ex["grounded"] for ex in batch)
def test_none_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
assert not any(ex["grounded"] for ex in batch)
def test_mixed_ratio(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
grounded_count = sum(1 for ex in batch if ex["grounded"])
# With 100 samples and p=0.5, should be roughly 50 +/- 20
assert 20 < grounded_count < 80
def test_batch_axis_strategy_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
axis_strategy="cardinal",
)
assert len(batch) == 10
def test_batch_parallel_axis_prob(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
parallel_axis_prob=0.5,
)
assert len(batch) == 10
# ---------------------------------------------------------------------------
# Seed reproducibility
# ---------------------------------------------------------------------------
class TestSeedReproducibility:
"""Different seeds produce different results."""
def test_different_seeds_differ(self) -> None:
g1 = SyntheticAssemblyGenerator(seed=1)
g2 = SyntheticAssemblyGenerator(seed=2)
b1 = g1.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
b2 = g2.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
c1 = [ex["assembly_classification"] for ex in b1]
c2 = [ex["assembly_classification"] for ex in b2]
r1 = [ex["is_rigid"] for ex in b1]
r2 = [ex["is_rigid"] for ex in b2]
assert c1 != c2 or r1 != r2
def test_same_seed_identical(self) -> None:
g1 = SyntheticAssemblyGenerator(seed=123)
g2 = SyntheticAssemblyGenerator(seed=123)
b1, j1, _ = g1.generate_tree_assembly(5)
b2, j2, _ = g2.generate_tree_assembly(5)
for a, b in zip(b1, b2, strict=True):
np.testing.assert_array_almost_equal(a.position, b.position)
assert len(j1) == len(j2)

View File

@@ -0,0 +1,267 @@
"""Tests for solver.datagen.jacobian -- Jacobian rank verification."""
from __future__ import annotations
import numpy as np
import pytest
from solver.datagen.jacobian import JacobianVerifier
from solver.datagen.types import Joint, JointType, RigidBody
# ---------------------------------------------------------------------------
# Fixtures
# ---------------------------------------------------------------------------
def _two_bodies() -> list[RigidBody]:
return [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
]
def _three_bodies() -> list[RigidBody]:
return [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
RigidBody(2, position=np.array([4.0, 0.0, 0.0])),
]
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestJacobianShape:
"""Verify Jacobian matrix dimensions for each joint type."""
@pytest.mark.parametrize(
"joint_type,expected_rows",
[
(JointType.FIXED, 6),
(JointType.REVOLUTE, 5),
(JointType.CYLINDRICAL, 4),
(JointType.SLIDER, 5),
(JointType.BALL, 3),
(JointType.PLANAR, 3),
(JointType.SCREW, 5),
(JointType.UNIVERSAL, 4),
(JointType.PARALLEL, 3),
(JointType.PERPENDICULAR, 1),
(JointType.DISTANCE, 1),
],
)
def test_row_count(self, joint_type: JointType, expected_rows: int) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
joint = Joint(
joint_id=0,
body_a=0,
body_b=1,
joint_type=joint_type,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
)
n_added = v.add_joint_constraints(joint)
assert n_added == expected_rows
j = v.get_jacobian()
assert j.shape == (expected_rows, 12) # 2 bodies * 6 cols
class TestNumericalRank:
"""Numerical rank checks for known configurations."""
def test_fixed_joint_rank_six(self) -> None:
"""Fixed joint between 2 bodies: rank = 6."""
bodies = _two_bodies()
v = JacobianVerifier(bodies)
j = Joint(
joint_id=0,
body_a=0,
body_b=1,
joint_type=JointType.FIXED,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
)
v.add_joint_constraints(j)
assert v.numerical_rank() == 6
def test_revolute_joint_rank_five(self) -> None:
"""Revolute joint between 2 bodies: rank = 5."""
bodies = _two_bodies()
v = JacobianVerifier(bodies)
j = Joint(
joint_id=0,
body_a=0,
body_b=1,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
)
v.add_joint_constraints(j)
assert v.numerical_rank() == 5
def test_ball_joint_rank_three(self) -> None:
"""Ball joint between 2 bodies: rank = 3."""
bodies = _two_bodies()
v = JacobianVerifier(bodies)
j = Joint(
joint_id=0,
body_a=0,
body_b=1,
joint_type=JointType.BALL,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
)
v.add_joint_constraints(j)
assert v.numerical_rank() == 3
def test_empty_jacobian_rank_zero(self) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
assert v.numerical_rank() == 0
class TestParallelAxesDegeneracy:
"""Parallel revolute axes create geometric dependencies."""
def _four_body_loop(self) -> list[RigidBody]:
return [
RigidBody(0, position=np.array([0.0, 0.0, 0.0])),
RigidBody(1, position=np.array([2.0, 0.0, 0.0])),
RigidBody(2, position=np.array([2.0, 2.0, 0.0])),
RigidBody(3, position=np.array([0.0, 2.0, 0.0])),
]
def _loop_joints(self, axes: list[np.ndarray]) -> list[Joint]:
pairs = [(0, 1, [1, 0, 0]), (1, 2, [2, 1, 0]), (2, 3, [1, 2, 0]), (3, 0, [0, 1, 0])]
return [
Joint(
joint_id=i,
body_a=a,
body_b=b,
joint_type=JointType.REVOLUTE,
anchor_a=np.array(anc, dtype=float),
anchor_b=np.array(anc, dtype=float),
axis=axes[i],
)
for i, (a, b, anc) in enumerate(pairs)
]
def test_parallel_has_lower_rank(self) -> None:
"""4-body closed loop: all-parallel revolute axes produce lower
Jacobian rank than mixed axes due to geometric dependency."""
bodies = self._four_body_loop()
z_axis = np.array([0.0, 0.0, 1.0])
# All axes parallel to Z
v_par = JacobianVerifier(bodies)
for j in self._loop_joints([z_axis] * 4):
v_par.add_joint_constraints(j)
rank_par = v_par.numerical_rank()
# Mixed axes
mixed = [
np.array([0.0, 0.0, 1.0]),
np.array([0.0, 1.0, 0.0]),
np.array([0.0, 0.0, 1.0]),
np.array([1.0, 0.0, 0.0]),
]
v_mix = JacobianVerifier(bodies)
for j in self._loop_joints(mixed):
v_mix.add_joint_constraints(j)
rank_mix = v_mix.numerical_rank()
assert rank_par < rank_mix
class TestFindDependencies:
"""Dependency detection."""
def test_fixed_joint_no_dependencies(self) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
j = Joint(
joint_id=0,
body_a=0,
body_b=1,
joint_type=JointType.FIXED,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
)
v.add_joint_constraints(j)
assert v.find_dependencies() == []
def test_duplicate_fixed_has_dependencies(self) -> None:
"""Two fixed joints on same pair: second is fully dependent."""
bodies = _two_bodies()
v = JacobianVerifier(bodies)
for jid in range(2):
v.add_joint_constraints(
Joint(
joint_id=jid,
body_a=0,
body_b=1,
joint_type=JointType.FIXED,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
)
)
deps = v.find_dependencies()
assert len(deps) == 6 # Second fixed joint entirely redundant
def test_empty_no_dependencies(self) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
assert v.find_dependencies() == []
class TestRowLabels:
"""Row label metadata."""
def test_labels_match_rows(self) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
j = Joint(
joint_id=7,
body_a=0,
body_b=1,
joint_type=JointType.REVOLUTE,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([1.0, 0.0, 0.0]),
axis=np.array([0.0, 0.0, 1.0]),
)
v.add_joint_constraints(j)
assert len(v.row_labels) == 5
assert all(lab["joint_id"] == 7 for lab in v.row_labels)
class TestPerpendicularPair:
"""Internal _perpendicular_pair utility."""
@pytest.mark.parametrize(
"axis",
[
np.array([1.0, 0.0, 0.0]),
np.array([0.0, 1.0, 0.0]),
np.array([0.0, 0.0, 1.0]),
np.array([1.0, 1.0, 1.0]) / np.sqrt(3),
],
)
def test_orthonormal(self, axis: np.ndarray) -> None:
bodies = _two_bodies()
v = JacobianVerifier(bodies)
t1, t2 = v._perpendicular_pair(axis)
# All unit length
np.testing.assert_allclose(np.linalg.norm(t1), 1.0, atol=1e-12)
np.testing.assert_allclose(np.linalg.norm(t2), 1.0, atol=1e-12)
# Mutually perpendicular
np.testing.assert_allclose(np.dot(axis, t1), 0.0, atol=1e-12)
np.testing.assert_allclose(np.dot(axis, t2), 0.0, atol=1e-12)
np.testing.assert_allclose(np.dot(t1, t2), 0.0, atol=1e-12)

View File

@@ -0,0 +1,346 @@
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
from __future__ import annotations
import json
import numpy as np
from solver.datagen.labeling import (
label_assembly,
)
from solver.datagen.types import Joint, JointType, RigidBody
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
return [RigidBody(body_id=i, position=np.array(pos)) for i, pos in enumerate(positions)]
def _make_joint(
jid: int,
a: int,
b: int,
jtype: JointType,
axis: tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Joint:
return Joint(
joint_id=jid,
body_a=a,
body_b=b,
joint_type=jtype,
anchor_a=np.zeros(3),
anchor_b=np.zeros(3),
axis=np.array(axis),
)
# ---------------------------------------------------------------------------
# Per-constraint labels
# ---------------------------------------------------------------------------
class TestConstraintLabels:
"""Per-constraint labels combine pebble game and Jacobian results."""
def test_fixed_joint_all_independent(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 6
for cl in labels.per_constraint:
assert cl.pebble_independent is True
assert cl.jacobian_independent is True
def test_revolute_joint_all_independent(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 5
for cl in labels.per_constraint:
assert cl.pebble_independent is True
assert cl.jacobian_independent is True
def test_chain_constraint_count(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 10 # 5 + 5
def test_constraint_joint_ids(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.BALL),
]
labels = label_assembly(bodies, joints, ground_body=0)
j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0]
j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1]
assert len(j0_constraints) == 5 # revolute
assert len(j1_constraints) == 3 # ball
def test_overconstrained_has_pebble_redundant(self) -> None:
"""Triangle with revolute joints: some constraints redundant."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent)
assert pebble_redundant > 0
# ---------------------------------------------------------------------------
# Per-joint labels
# ---------------------------------------------------------------------------
class TestJointLabels:
"""Per-joint aggregated labels."""
def test_fixed_joint_counts(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_joint) == 1
jl = labels.per_joint[0]
assert jl.joint_id == 0
assert jl.independent_count == 6
assert jl.redundant_count == 0
assert jl.total == 6
def test_overconstrained_has_redundant_joints(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
total_redundant = sum(jl.redundant_count for jl in labels.per_joint)
assert total_redundant > 0
def test_joint_total_equals_dof(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.BALL)]
labels = label_assembly(bodies, joints, ground_body=0)
jl = labels.per_joint[0]
assert jl.total == 3 # ball has 3 DOF
# ---------------------------------------------------------------------------
# Per-body DOF labels
# ---------------------------------------------------------------------------
class TestBodyDofLabels:
"""Per-body DOF signatures from nullspace projection."""
def test_fixed_joint_grounded_both_zero(self) -> None:
"""Two bodies + fixed joint + grounded: both fully constrained."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
for bl in labels.per_body:
assert bl.translational_dof == 0
assert bl.rotational_dof == 0
def test_revolute_has_rotational_dof(self) -> None:
"""Two bodies + revolute + grounded: body 1 has rotational DOF."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
b1 = next(b for b in labels.per_body if b.body_id == 1)
# Revolute allows 1 rotation DOF
assert b1.rotational_dof >= 1
def test_dof_bounds(self) -> None:
"""All DOF values should be in [0, 3]."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
for bl in labels.per_body:
assert 0 <= bl.translational_dof <= 3
assert 0 <= bl.rotational_dof <= 3
def test_floating_more_dof_than_grounded(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
grounded = label_assembly(bodies, joints, ground_body=0)
floating = label_assembly(bodies, joints, ground_body=None)
g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body)
f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body)
assert f_total > g_total
def test_grounded_body_zero_dof(self) -> None:
"""The grounded body should have 0 DOF."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
b0 = next(b for b in labels.per_body if b.body_id == 0)
assert b0.translational_dof == 0
assert b0.rotational_dof == 0
def test_body_count_matches(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.BALL),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_body) == 3
# ---------------------------------------------------------------------------
# Assembly label
# ---------------------------------------------------------------------------
class TestAssemblyLabel:
"""Assembly-wide summary labels."""
def test_underconstrained_chain(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.classification == "underconstrained"
assert labels.assembly.is_rigid is False
def test_well_constrained(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.classification == "well-constrained"
assert labels.assembly.is_rigid is True
assert labels.assembly.is_minimally_rigid is True
def test_overconstrained(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.redundant_count > 0
def test_has_degeneracy_with_parallel_axes(self) -> None:
"""Parallel revolute axes in a loop create geometric degeneracy."""
z_axis = (0.0, 0.0, 1.0)
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis),
_make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis),
_make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis),
_make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.has_degeneracy is True
# ---------------------------------------------------------------------------
# Serialization
# ---------------------------------------------------------------------------
class TestToDict:
"""to_dict produces JSON-serializable output."""
def test_top_level_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
assert set(d.keys()) == {
"per_constraint",
"per_joint",
"per_body",
"assembly",
}
def test_per_constraint_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
for item in d["per_constraint"]:
assert set(item.keys()) == {
"joint_id",
"constraint_idx",
"pebble_independent",
"jacobian_independent",
}
def test_assembly_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
assert set(d["assembly"].keys()) == {
"classification",
"total_dof",
"redundant_count",
"is_rigid",
"is_minimally_rigid",
"has_degeneracy",
}
def test_json_serializable(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
# Should not raise
serialized = json.dumps(d)
assert isinstance(serialized, str)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestLabelAssemblyEdgeCases:
"""Edge cases for label_assembly."""
def test_no_joints(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
labels = label_assembly(bodies, [], ground_body=0)
assert len(labels.per_constraint) == 0
assert len(labels.per_joint) == 0
assert labels.assembly.classification == "underconstrained"
# Non-ground body should be fully free
b1 = next(b for b in labels.per_body if b.body_id == 1)
assert b1.translational_dof == 3
assert b1.rotational_dof == 3
def test_no_joints_floating(self) -> None:
bodies = _make_bodies((0, 0, 0))
labels = label_assembly(bodies, [], ground_body=None)
assert len(labels.per_body) == 1
assert labels.per_body[0].translational_dof == 3
assert labels.per_body[0].rotational_dof == 3
def test_analysis_embedded(self) -> None:
"""AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
analysis = labels.analysis
assert hasattr(analysis, "combinatorial_classification")
assert hasattr(analysis, "jacobian_rank")
assert hasattr(analysis, "is_rigid")

View File

@@ -0,0 +1,206 @@
"""Tests for solver.datagen.pebble_game -- (6,6)-pebble game."""
from __future__ import annotations
import itertools
import numpy as np
import pytest
from solver.datagen.pebble_game import PebbleGame3D
from solver.datagen.types import Joint, JointType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _revolute(jid: int, a: int, b: int, axis: np.ndarray | None = None) -> Joint:
"""Shorthand for a revolute joint between bodies *a* and *b*."""
if axis is None:
axis = np.array([0.0, 0.0, 1.0])
return Joint(
joint_id=jid,
body_a=a,
body_b=b,
joint_type=JointType.REVOLUTE,
axis=axis,
)
def _fixed(jid: int, a: int, b: int) -> Joint:
return Joint(joint_id=jid, body_a=a, body_b=b, joint_type=JointType.FIXED)
# ---------------------------------------------------------------------------
# Tests
# ---------------------------------------------------------------------------
class TestAddBody:
"""Body registration basics."""
def test_single_body_six_pebbles(self) -> None:
pg = PebbleGame3D()
pg.add_body(0)
assert pg.state.free_pebbles[0] == 6
def test_duplicate_body_no_op(self) -> None:
pg = PebbleGame3D()
pg.add_body(0)
pg.add_body(0)
assert pg.state.free_pebbles[0] == 6
def test_multiple_bodies(self) -> None:
pg = PebbleGame3D()
for i in range(5):
pg.add_body(i)
assert pg.get_dof() == 30 # 5 * 6
class TestAddJoint:
"""Joint insertion and DOF accounting."""
def test_revolute_removes_five_dof(self) -> None:
pg = PebbleGame3D()
results = pg.add_joint(_revolute(0, 0, 1))
assert len(results) == 5 # 5 scalar constraints
assert all(r["independent"] for r in results)
# 2 bodies * 6 = 12, minus 5 independent = 7 free pebbles
assert pg.get_dof() == 7
def test_fixed_removes_six_dof(self) -> None:
pg = PebbleGame3D()
results = pg.add_joint(_fixed(0, 0, 1))
assert len(results) == 6
assert all(r["independent"] for r in results)
assert pg.get_dof() == 6
def test_ball_removes_three_dof(self) -> None:
pg = PebbleGame3D()
j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.BALL)
results = pg.add_joint(j)
assert len(results) == 3
assert all(r["independent"] for r in results)
assert pg.get_dof() == 9
class TestTwoBodiesRevolute:
"""Two bodies connected by a revolute -- demo scenario 1."""
def test_internal_dof(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_revolute(0, 0, 1))
# Total DOF = 7, internal = 7 - 6 = 1
assert pg.get_internal_dof() == 1
def test_not_rigid(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_revolute(0, 0, 1))
assert not pg.is_rigid()
def test_classification(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_revolute(0, 0, 1))
assert pg.classify_assembly() == "underconstrained"
class TestTwoBodiesFixed:
"""Two bodies + fixed joint -- demo scenario 2."""
def test_zero_internal_dof(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_fixed(0, 0, 1))
assert pg.get_internal_dof() == 0
def test_rigid(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_fixed(0, 0, 1))
assert pg.is_rigid()
def test_well_constrained(self) -> None:
pg = PebbleGame3D()
pg.add_joint(_fixed(0, 0, 1))
assert pg.classify_assembly() == "well-constrained"
class TestTriangleRevolute:
"""Triangle of 3 bodies with revolute joints -- demo scenario 3."""
@pytest.fixture()
def pg(self) -> PebbleGame3D:
pg = PebbleGame3D()
pg.add_joint(_revolute(0, 0, 1))
pg.add_joint(_revolute(1, 1, 2))
pg.add_joint(_revolute(2, 2, 0))
return pg
def test_has_redundant_edges(self, pg: PebbleGame3D) -> None:
assert pg.get_redundant_count() > 0
def test_classification_overconstrained(self, pg: PebbleGame3D) -> None:
# 15 constraints on 3 bodies (Maxwell: 6*3-6=12 needed)
assert pg.classify_assembly() in ("overconstrained", "mixed")
def test_rigid(self, pg: PebbleGame3D) -> None:
assert pg.is_rigid()
class TestChainNotRigid:
"""A serial chain of 4 bodies with revolute joints is never rigid."""
def test_chain_underconstrained(self) -> None:
pg = PebbleGame3D()
for i in range(3):
pg.add_joint(_revolute(i, i, i + 1))
assert not pg.is_rigid()
assert pg.classify_assembly() == "underconstrained"
def test_chain_internal_dof(self) -> None:
pg = PebbleGame3D()
for i in range(3):
pg.add_joint(_revolute(i, i, i + 1))
# 4 bodies * 6 = 24, minus 15 independent = 9 free, internal = 3
assert pg.get_internal_dof() == 3
class TestEdgeResults:
"""Result dicts returned by add_joint."""
def test_result_keys(self) -> None:
pg = PebbleGame3D()
results = pg.add_joint(_revolute(0, 0, 1))
expected_keys = {"edge_id", "joint_id", "constraint_index", "independent", "dof_remaining"}
for r in results:
assert set(r.keys()) == expected_keys
def test_edge_ids_sequential(self) -> None:
pg = PebbleGame3D()
r1 = pg.add_joint(_revolute(0, 0, 1))
r2 = pg.add_joint(_revolute(1, 1, 2))
all_ids = [r["edge_id"] for r in r1 + r2]
assert all_ids == list(range(10))
def test_dof_remaining_monotonic(self) -> None:
pg = PebbleGame3D()
results = pg.add_joint(_revolute(0, 0, 1))
dofs = [r["dof_remaining"] for r in results]
# Should be non-increasing (each independent edge removes a pebble)
for a, b in itertools.pairwise(dofs):
assert a >= b
class TestGroundedClassification:
"""classify_assembly with grounded=True."""
def test_grounded_baseline_zero(self) -> None:
"""With grounded=True the baseline is 0 (not 6)."""
pg = PebbleGame3D()
pg.add_joint(_fixed(0, 0, 1))
# Ungrounded: well-constrained (6 pebbles = baseline 6)
assert pg.classify_assembly(grounded=False) == "well-constrained"
# Grounded: the 6 remaining pebbles on body 1 exceed baseline 0,
# so the raw pebble game (without a virtual ground body) sees this
# as underconstrained. The analysis function handles this properly
# by adding a virtual ground body.
assert pg.classify_assembly(grounded=True) == "underconstrained"

163
tests/datagen/test_types.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for solver.datagen.types -- shared data types."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.datagen.types import (
ConstraintAnalysis,
Joint,
JointType,
PebbleState,
RigidBody,
)
class TestJointType:
"""JointType enum construction and DOF values."""
EXPECTED_DOF: ClassVar[dict[str, int]] = {
"FIXED": 6,
"REVOLUTE": 5,
"CYLINDRICAL": 4,
"SLIDER": 5,
"BALL": 3,
"PLANAR": 3,
"SCREW": 5,
"UNIVERSAL": 4,
"PARALLEL": 3,
"PERPENDICULAR": 1,
"DISTANCE": 1,
}
def test_member_count(self) -> None:
assert len(JointType) == 11
@pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
def test_dof_values(self, name: str, dof: int) -> None:
assert JointType[name].dof == dof
def test_access_by_name(self) -> None:
assert JointType["REVOLUTE"] is JointType.REVOLUTE
def test_value_is_tuple(self) -> None:
assert JointType.REVOLUTE.value == (1, 5)
assert JointType.REVOLUTE.dof == 5
class TestRigidBody:
"""RigidBody dataclass defaults and construction."""
def test_defaults(self) -> None:
body = RigidBody(body_id=0)
np.testing.assert_array_equal(body.position, np.zeros(3))
np.testing.assert_array_equal(body.orientation, np.eye(3))
assert body.local_anchors == {}
def test_custom_position(self) -> None:
pos = np.array([1.0, 2.0, 3.0])
body = RigidBody(body_id=7, position=pos)
np.testing.assert_array_equal(body.position, pos)
assert body.body_id == 7
def test_local_anchors_mutable(self) -> None:
body = RigidBody(body_id=0)
body.local_anchors["top"] = np.array([0.0, 0.0, 1.0])
assert "top" in body.local_anchors
def test_default_factory_isolation(self) -> None:
"""Each instance gets its own default containers."""
b1 = RigidBody(body_id=0)
b2 = RigidBody(body_id=1)
b1.local_anchors["x"] = np.zeros(3)
assert "x" not in b2.local_anchors
class TestJoint:
"""Joint dataclass defaults and construction."""
def test_defaults(self) -> None:
j = Joint(joint_id=0, body_a=0, body_b=1, joint_type=JointType.REVOLUTE)
np.testing.assert_array_equal(j.anchor_a, np.zeros(3))
np.testing.assert_array_equal(j.anchor_b, np.zeros(3))
np.testing.assert_array_equal(j.axis, np.array([0.0, 0.0, 1.0]))
assert j.pitch == 0.0
def test_full_construction(self) -> None:
j = Joint(
joint_id=5,
body_a=2,
body_b=3,
joint_type=JointType.SCREW,
anchor_a=np.array([1.0, 0.0, 0.0]),
anchor_b=np.array([2.0, 0.0, 0.0]),
axis=np.array([1.0, 0.0, 0.0]),
pitch=0.5,
)
assert j.joint_id == 5
assert j.joint_type is JointType.SCREW
assert j.pitch == 0.5
class TestPebbleState:
"""PebbleState dataclass defaults."""
def test_defaults(self) -> None:
s = PebbleState()
assert s.free_pebbles == {}
assert s.directed_edges == {}
assert s.independent_edges == set()
assert s.redundant_edges == set()
assert s.incoming == {}
assert s.outgoing == {}
def test_default_factory_isolation(self) -> None:
s1 = PebbleState()
s2 = PebbleState()
s1.free_pebbles[0] = 6
assert 0 not in s2.free_pebbles
class TestConstraintAnalysis:
"""ConstraintAnalysis dataclass construction."""
def test_construction(self) -> None:
ca = ConstraintAnalysis(
combinatorial_dof=6,
combinatorial_internal_dof=0,
combinatorial_redundant=0,
combinatorial_classification="well-constrained",
per_edge_results=[],
jacobian_rank=6,
jacobian_nullity=0,
jacobian_internal_dof=0,
numerically_dependent=[],
geometric_degeneracies=0,
is_rigid=True,
is_minimally_rigid=True,
)
assert ca.is_rigid is True
assert ca.is_minimally_rigid is True
assert ca.combinatorial_classification == "well-constrained"
def test_per_edge_results_typing(self) -> None:
"""per_edge_results accepts list[dict[str, Any]]."""
ca = ConstraintAnalysis(
combinatorial_dof=7,
combinatorial_internal_dof=1,
combinatorial_redundant=0,
combinatorial_classification="underconstrained",
per_edge_results=[{"edge_id": 0, "independent": True}],
jacobian_rank=5,
jacobian_nullity=1,
jacobian_internal_dof=1,
numerically_dependent=[],
geometric_degeneracies=0,
is_rigid=False,
is_minimally_rigid=False,
)
assert len(ca.per_edge_results) == 1
assert ca.per_edge_results[0]["edge_id"] == 0

0
tests/mates/__init__.py Normal file
View File

View File

@@ -0,0 +1,287 @@
"""Tests for solver.mates.conversion -- mate-to-joint conversion."""
from __future__ import annotations
import numpy as np
from solver.datagen.types import JointType, RigidBody
from solver.mates.conversion import (
MateAnalysisResult,
analyze_mate_assembly,
convert_mates_to_joints,
)
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ref(
body_id: int,
geom_type: GeometryType,
*,
origin: np.ndarray | None = None,
direction: np.ndarray | None = None,
) -> GeometryRef:
"""Factory for GeometryRef with sensible defaults."""
if origin is None:
origin = np.zeros(3)
if direction is None and geom_type in {
GeometryType.FACE,
GeometryType.AXIS,
GeometryType.PLANE,
}:
direction = np.array([0.0, 0.0, 1.0])
return GeometryRef(
body_id=body_id,
geometry_type=geom_type,
geometry_id="Geom001",
origin=origin,
direction=direction,
)
def _make_bodies(n: int) -> list[RigidBody]:
"""Create n bodies at distinct positions."""
return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)]
# ---------------------------------------------------------------------------
# convert_mates_to_joints
# ---------------------------------------------------------------------------
class TestConvertMatesToJoints:
"""convert_mates_to_joints function."""
def test_empty_input(self) -> None:
joints, m2j, j2m = convert_mates_to_joints([])
assert joints == []
assert m2j == {}
assert j2m == {}
def test_hinge_pattern(self) -> None:
"""Concentric + Coincident(plane) -> single REVOLUTE joint."""
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS),
ref_b=_make_ref(1, GeometryType.AXIS),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
]
joints, m2j, j2m = convert_mates_to_joints(mates)
assert len(joints) == 1
assert joints[0].joint_type is JointType.REVOLUTE
assert joints[0].body_a == 0
assert joints[0].body_b == 1
# Both mates map to the single joint
assert 0 in m2j
assert 1 in m2j
assert j2m[joints[0].joint_id] == [0, 1]
def test_lock_pattern(self) -> None:
"""Lock -> FIXED joint."""
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
joints, _m2j, _j2m = convert_mates_to_joints(mates)
assert len(joints) == 1
assert joints[0].joint_type is JointType.FIXED
def test_unmatched_mate_fallback(self) -> None:
"""A single ANGLE mate with no pattern -> individual joint."""
mates = [
Mate(
mate_id=0,
mate_type=MateType.ANGLE,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
joints, _m2j, _j2m = convert_mates_to_joints(mates)
assert len(joints) == 1
assert joints[0].joint_type is JointType.PERPENDICULAR
def test_mapping_consistency(self) -> None:
"""mate_to_joint and joint_to_mates are consistent."""
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS),
ref_b=_make_ref(1, GeometryType.AXIS),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
Mate(
mate_id=2,
mate_type=MateType.DISTANCE,
ref_a=_make_ref(2, GeometryType.POINT),
ref_b=_make_ref(3, GeometryType.POINT),
),
]
joints, m2j, j2m = convert_mates_to_joints(mates)
# Every mate should be in m2j
for mate in mates:
assert mate.mate_id in m2j
# Every joint should be in j2m
for joint in joints:
assert joint.joint_id in j2m
def test_joint_axis_from_geometry(self) -> None:
"""Joint axis should come from mate geometry direction."""
axis_dir = np.array([1.0, 0.0, 0.0])
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS, direction=axis_dir),
ref_b=_make_ref(1, GeometryType.AXIS, direction=axis_dir),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
]
joints, _, _ = convert_mates_to_joints(mates)
np.testing.assert_array_almost_equal(joints[0].axis, axis_dir)
# ---------------------------------------------------------------------------
# MateAnalysisResult
# ---------------------------------------------------------------------------
class TestMateAnalysisResult:
"""MateAnalysisResult dataclass."""
def test_to_dict(self) -> None:
result = MateAnalysisResult(
patterns=[],
joints=[],
)
d = result.to_dict()
assert d["patterns"] == []
assert d["joints"] == []
assert d["labels"] is None
# ---------------------------------------------------------------------------
# analyze_mate_assembly
# ---------------------------------------------------------------------------
class TestAnalyzeMateAssembly:
"""Full pipeline: mates -> joints -> analysis."""
def test_two_bodies_hinge(self) -> None:
"""Two bodies connected by hinge mates -> underconstrained (1 DOF)."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS),
ref_b=_make_ref(1, GeometryType.AXIS),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
]
result = analyze_mate_assembly(bodies, mates)
assert result.analysis is not None
assert result.labels is not None
# A revolute joint removes 5 DOF, leaving 1 internal DOF
assert result.analysis.combinatorial_internal_dof == 1
assert len(result.joints) == 1
assert result.joints[0].joint_type is JointType.REVOLUTE
def test_two_bodies_fixed(self) -> None:
"""Two bodies with lock mate -> well-constrained."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = analyze_mate_assembly(bodies, mates)
assert result.analysis is not None
assert result.analysis.combinatorial_internal_dof == 0
assert result.analysis.is_rigid
def test_grounded_assembly(self) -> None:
"""Grounded assembly analysis works."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = analyze_mate_assembly(bodies, mates, ground_body=0)
assert result.analysis is not None
assert result.analysis.is_rigid
def test_no_mates(self) -> None:
"""Assembly with no mates should be fully underconstrained."""
bodies = _make_bodies(2)
result = analyze_mate_assembly(bodies, [])
assert result.analysis is not None
assert result.analysis.combinatorial_internal_dof == 6
assert len(result.joints) == 0
def test_single_body(self) -> None:
"""Single body, no mates."""
bodies = _make_bodies(1)
result = analyze_mate_assembly(bodies, [])
assert result.analysis is not None
assert len(result.joints) == 0
def test_result_traceability(self) -> None:
"""mate_to_joint and joint_to_mates populated in result."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS),
ref_b=_make_ref(1, GeometryType.AXIS),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
]
result = analyze_mate_assembly(bodies, mates)
assert 0 in result.mate_to_joint
assert 1 in result.mate_to_joint
assert len(result.joint_to_mates) > 0

View File

@@ -0,0 +1,155 @@
"""Tests for solver.mates.generator -- synthetic mate generator."""
from __future__ import annotations
from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch
from solver.mates.primitives import MateType
# ---------------------------------------------------------------------------
# SyntheticMateGenerator
# ---------------------------------------------------------------------------
class TestSyntheticMateGenerator:
"""SyntheticMateGenerator core functionality."""
def test_generate_basic(self) -> None:
"""Generate a simple assembly with mates."""
gen = SyntheticMateGenerator(seed=42)
bodies, mates, result = gen.generate(3)
assert len(bodies) == 3
assert len(mates) > 0
assert result.analysis is not None
def test_deterministic_with_seed(self) -> None:
"""Same seed produces same output."""
gen1 = SyntheticMateGenerator(seed=123)
_, mates1, _ = gen1.generate(3)
gen2 = SyntheticMateGenerator(seed=123)
_, mates2, _ = gen2.generate(3)
assert len(mates1) == len(mates2)
for m1, m2 in zip(mates1, mates2, strict=True):
assert m1.mate_type == m2.mate_type
assert m1.ref_a.body_id == m2.ref_a.body_id
def test_grounded(self) -> None:
"""Grounded assembly should work."""
gen = SyntheticMateGenerator(seed=42)
bodies, _mates, result = gen.generate(3, grounded=True)
assert len(bodies) == 3
assert result.analysis is not None
def test_revolute_produces_two_mates(self) -> None:
"""A revolute joint should reverse-map to 2 mates."""
gen = SyntheticMateGenerator(seed=42)
_bodies, mates, _result = gen.generate(2)
# 2 bodies -> 1 revolute joint -> 2 mates (concentric + coincident)
assert len(mates) == 2
mate_types = {m.mate_type for m in mates}
assert MateType.CONCENTRIC in mate_types
assert MateType.COINCIDENT in mate_types
class TestReverseMapping:
"""Reverse mapping from joints to mates."""
def test_revolute_mapping(self) -> None:
"""REVOLUTE -> Concentric + Coincident."""
gen = SyntheticMateGenerator(seed=42)
_bodies, mates, _result = gen.generate(2)
types = [m.mate_type for m in mates]
assert MateType.CONCENTRIC in types
assert MateType.COINCIDENT in types
def test_round_trip_analysis(self) -> None:
"""Generated mates round-trip through analysis successfully."""
gen = SyntheticMateGenerator(seed=42)
_bodies, _mates, result = gen.generate(4)
assert result.analysis is not None
assert result.labels is not None
# Should produce joints from the mates
assert len(result.joints) > 0
class TestNoiseInjection:
"""Noise injection mechanisms."""
def test_redundant_injection(self) -> None:
"""Redundant prob > 0 produces more mates than clean version."""
gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0)
_, mates_clean, _ = gen_clean.generate(4)
gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0)
_, mates_noisy, _ = gen_noisy.generate(4)
assert len(mates_noisy) > len(mates_clean)
def test_missing_injection(self) -> None:
"""Missing prob > 0 produces fewer mates than clean version."""
gen_clean = SyntheticMateGenerator(seed=42, missing_prob=0.0)
_, mates_clean, _ = gen_clean.generate(4)
gen_noisy = SyntheticMateGenerator(seed=42, missing_prob=0.5)
_, mates_noisy, _ = gen_noisy.generate(4)
# With 50% drop rate on 6 mates, very likely to drop at least one
assert len(mates_noisy) <= len(mates_clean)
def test_incompatible_injection(self) -> None:
"""Incompatible prob > 0 adds mates with wrong geometry."""
gen = SyntheticMateGenerator(seed=42, incompatible_prob=1.0)
_, mates, _ = gen.generate(3)
# Should have extra mates beyond the clean count
gen_clean = SyntheticMateGenerator(seed=42)
_, mates_clean, _ = gen_clean.generate(3)
assert len(mates) > len(mates_clean)
# ---------------------------------------------------------------------------
# generate_mate_training_batch
# ---------------------------------------------------------------------------
class TestGenerateMateTrainingBatch:
"""Batch generation function."""
def test_batch_structure(self) -> None:
"""Each example has required keys."""
examples = generate_mate_training_batch(batch_size=3, seed=42)
assert len(examples) == 3
for ex in examples:
assert "bodies" in ex
assert "mates" in ex
assert "patterns" in ex
assert "labels" in ex
assert "n_bodies" in ex
assert "n_mates" in ex
assert "n_joints" in ex
def test_batch_deterministic(self) -> None:
"""Same seed produces same batch."""
batch1 = generate_mate_training_batch(batch_size=5, seed=99)
batch2 = generate_mate_training_batch(batch_size=5, seed=99)
for ex1, ex2 in zip(batch1, batch2, strict=True):
assert ex1["n_bodies"] == ex2["n_bodies"]
assert ex1["n_mates"] == ex2["n_mates"]
def test_batch_grounded_ratio(self) -> None:
"""Batch respects grounded_ratio parameter."""
# All grounded
examples = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0)
assert len(examples) == 5
def test_batch_with_noise(self) -> None:
"""Batch with noise injection runs without error."""
examples = generate_mate_training_batch(
batch_size=3,
seed=42,
redundant_prob=0.3,
missing_prob=0.1,
)
assert len(examples) == 3
for ex in examples:
assert ex["n_mates"] >= 0

View File

@@ -0,0 +1,224 @@
"""Tests for solver.mates.labeling -- mate-level ground truth labels."""
from __future__ import annotations
import numpy as np
from solver.datagen.types import RigidBody
from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly
from solver.mates.patterns import JointPattern
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ref(
body_id: int,
geom_type: GeometryType,
*,
origin: np.ndarray | None = None,
direction: np.ndarray | None = None,
) -> GeometryRef:
"""Factory for GeometryRef with sensible defaults."""
if origin is None:
origin = np.zeros(3)
if direction is None and geom_type in {
GeometryType.FACE,
GeometryType.AXIS,
GeometryType.PLANE,
}:
direction = np.array([0.0, 0.0, 1.0])
return GeometryRef(
body_id=body_id,
geometry_type=geom_type,
geometry_id="Geom001",
origin=origin,
direction=direction,
)
def _make_bodies(n: int) -> list[RigidBody]:
"""Create n bodies at distinct positions."""
return [RigidBody(body_id=i, position=np.array([float(i), 0.0, 0.0])) for i in range(n)]
# ---------------------------------------------------------------------------
# MateLabel
# ---------------------------------------------------------------------------
class TestMateLabel:
"""MateLabel dataclass."""
def test_defaults(self) -> None:
ml = MateLabel(mate_id=0)
assert ml.is_independent is True
assert ml.is_redundant is False
assert ml.is_degenerate is False
assert ml.pattern is None
assert ml.issue is None
def test_to_dict(self) -> None:
ml = MateLabel(
mate_id=5,
is_independent=False,
is_redundant=True,
pattern=JointPattern.HINGE,
issue="redundant",
)
d = ml.to_dict()
assert d["mate_id"] == 5
assert d["is_redundant"] is True
assert d["pattern"] == "hinge"
assert d["issue"] == "redundant"
def test_to_dict_none_pattern(self) -> None:
ml = MateLabel(mate_id=0)
d = ml.to_dict()
assert d["pattern"] is None
# ---------------------------------------------------------------------------
# MateAssemblyLabels
# ---------------------------------------------------------------------------
class TestMateAssemblyLabels:
"""MateAssemblyLabels dataclass."""
def test_to_dict_structure(self) -> None:
"""to_dict produces expected keys."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = label_mate_assembly(bodies, mates)
d = result.to_dict()
assert "per_mate" in d
assert "patterns" in d
assert "assembly" in d
assert isinstance(d["per_mate"], list)
# ---------------------------------------------------------------------------
# label_mate_assembly
# ---------------------------------------------------------------------------
class TestLabelMateAssembly:
"""Full labeling pipeline."""
def test_clean_assembly_no_redundancy(self) -> None:
"""Two bodies with lock mate -> clean, no redundancy."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = label_mate_assembly(bodies, mates)
assert isinstance(result, MateAssemblyLabels)
assert len(result.per_mate) == 1
ml = result.per_mate[0]
assert ml.mate_id == 0
assert ml.is_independent is True
assert ml.is_redundant is False
assert ml.issue is None
def test_redundant_assembly(self) -> None:
"""Two lock mates on same body pair -> one is redundant."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
Mate(
mate_id=1,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
),
]
result = label_mate_assembly(bodies, mates)
assert len(result.per_mate) == 2
redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant)
# At least one should be redundant
assert redundant_count >= 1
assert result.assembly.redundant_count > 0
def test_hinge_pattern_labeling(self) -> None:
"""Hinge mates get pattern membership."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.CONCENTRIC,
ref_a=_make_ref(0, GeometryType.AXIS),
ref_b=_make_ref(1, GeometryType.AXIS),
),
Mate(
mate_id=1,
mate_type=MateType.COINCIDENT,
ref_a=_make_ref(0, GeometryType.PLANE),
ref_b=_make_ref(1, GeometryType.PLANE),
),
]
result = label_mate_assembly(bodies, mates)
assert len(result.per_mate) == 2
# Both mates should be part of the hinge pattern
for ml in result.per_mate:
assert ml.pattern is JointPattern.HINGE
assert ml.is_independent is True
def test_grounded_assembly(self) -> None:
"""Grounded assembly labeling works."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = label_mate_assembly(bodies, mates, ground_body=0)
assert result.assembly.is_rigid
def test_empty_mates(self) -> None:
"""No mates -> no per_mate labels, underconstrained."""
bodies = _make_bodies(2)
result = label_mate_assembly(bodies, [])
assert len(result.per_mate) == 0
assert result.assembly.classification == "underconstrained"
def test_assembly_classification(self) -> None:
"""Assembly classification is present."""
bodies = _make_bodies(2)
mates = [
Mate(
mate_id=0,
mate_type=MateType.LOCK,
ref_a=_make_ref(0, GeometryType.FACE),
ref_b=_make_ref(1, GeometryType.FACE),
),
]
result = label_mate_assembly(bodies, mates)
assert result.assembly.classification in {
"well-constrained",
"overconstrained",
"underconstrained",
"mixed",
}

View File

@@ -0,0 +1,285 @@
"""Tests for solver.mates.patterns -- joint pattern recognition."""
from __future__ import annotations
import numpy as np
from solver.datagen.types import JointType
from solver.mates.patterns import JointPattern, PatternMatch, recognize_patterns
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ref(
body_id: int,
geom_type: GeometryType,
*,
geometry_id: str = "Geom001",
origin: np.ndarray | None = None,
direction: np.ndarray | None = None,
) -> GeometryRef:
"""Factory for GeometryRef with sensible defaults."""
if origin is None:
origin = np.zeros(3)
if direction is None and geom_type in {
GeometryType.FACE,
GeometryType.AXIS,
GeometryType.PLANE,
}:
direction = np.array([0.0, 0.0, 1.0])
return GeometryRef(
body_id=body_id,
geometry_type=geom_type,
geometry_id=geometry_id,
origin=origin,
direction=direction,
)
def _make_mate(
mate_id: int,
mate_type: MateType,
body_a: int,
body_b: int,
geom_a: GeometryType = GeometryType.FACE,
geom_b: GeometryType = GeometryType.FACE,
) -> Mate:
"""Factory for Mate with body pair and geometry types."""
return Mate(
mate_id=mate_id,
mate_type=mate_type,
ref_a=_make_ref(body_a, geom_a),
ref_b=_make_ref(body_b, geom_b),
)
# ---------------------------------------------------------------------------
# JointPattern enum
# ---------------------------------------------------------------------------
class TestJointPattern:
"""JointPattern enum."""
def test_member_count(self) -> None:
assert len(JointPattern) == 9
def test_string_values(self) -> None:
for jp in JointPattern:
assert isinstance(jp.value, str)
def test_access_by_name(self) -> None:
assert JointPattern["HINGE"] is JointPattern.HINGE
# ---------------------------------------------------------------------------
# PatternMatch
# ---------------------------------------------------------------------------
class TestPatternMatch:
"""PatternMatch dataclass."""
def test_construction(self) -> None:
mate = _make_mate(0, MateType.LOCK, 0, 1)
pm = PatternMatch(
pattern=JointPattern.FIXED,
mates=[mate],
body_a=0,
body_b=1,
confidence=1.0,
equivalent_joint_type=JointType.FIXED,
)
assert pm.pattern is JointPattern.FIXED
assert pm.confidence == 1.0
assert pm.missing_mates == []
def test_to_dict(self) -> None:
mate = _make_mate(5, MateType.LOCK, 0, 1)
pm = PatternMatch(
pattern=JointPattern.FIXED,
mates=[mate],
body_a=0,
body_b=1,
confidence=1.0,
equivalent_joint_type=JointType.FIXED,
)
d = pm.to_dict()
assert d["pattern"] == "fixed"
assert d["mate_ids"] == [5]
assert d["equivalent_joint_type"] == "FIXED"
# ---------------------------------------------------------------------------
# recognize_patterns — canonical patterns
# ---------------------------------------------------------------------------
class TestRecognizeCanonical:
"""Full-confidence canonical pattern recognition."""
def test_empty_input(self) -> None:
assert recognize_patterns([]) == []
def test_hinge(self) -> None:
"""Concentric(axis) + Coincident(plane) -> Hinge."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
_make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
]
results = recognize_patterns(mates)
top = results[0]
assert top.pattern is JointPattern.HINGE
assert top.confidence == 1.0
assert top.equivalent_joint_type is JointType.REVOLUTE
assert top.missing_mates == []
def test_slider(self) -> None:
"""Coincident(plane) + Parallel(axis) -> Slider."""
mates = [
_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
_make_mate(1, MateType.PARALLEL, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
]
results = recognize_patterns(mates)
top = results[0]
assert top.pattern is JointPattern.SLIDER
assert top.confidence == 1.0
assert top.equivalent_joint_type is JointType.SLIDER
def test_cylinder(self) -> None:
"""Concentric(axis) only -> Cylinder."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
]
results = recognize_patterns(mates)
# Should match cylinder at confidence 1.0
cylinder = [r for r in results if r.pattern is JointPattern.CYLINDER]
assert len(cylinder) >= 1
assert cylinder[0].confidence == 1.0
assert cylinder[0].equivalent_joint_type is JointType.CYLINDRICAL
def test_ball(self) -> None:
"""Coincident(point) -> Ball."""
mates = [
_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.POINT, GeometryType.POINT),
]
results = recognize_patterns(mates)
top = results[0]
assert top.pattern is JointPattern.BALL
assert top.confidence == 1.0
assert top.equivalent_joint_type is JointType.BALL
def test_planar_face(self) -> None:
"""Coincident(face) -> Planar."""
mates = [
_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.FACE, GeometryType.FACE),
]
results = recognize_patterns(mates)
top = results[0]
assert top.pattern is JointPattern.PLANAR
assert top.confidence == 1.0
assert top.equivalent_joint_type is JointType.PLANAR
def test_fixed(self) -> None:
"""Lock -> Fixed."""
mates = [
_make_mate(0, MateType.LOCK, 0, 1, GeometryType.FACE, GeometryType.FACE),
]
results = recognize_patterns(mates)
top = results[0]
assert top.pattern is JointPattern.FIXED
assert top.confidence == 1.0
assert top.equivalent_joint_type is JointType.FIXED
# ---------------------------------------------------------------------------
# recognize_patterns — partial matches
# ---------------------------------------------------------------------------
class TestRecognizePartial:
"""Partial pattern matches and hints."""
def test_concentric_without_plane_hints_hinge(self) -> None:
"""Concentric alone matches hinge at 0.5 confidence with missing hint."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
]
results = recognize_patterns(mates)
hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE]
assert len(hinge_matches) >= 1
hinge = hinge_matches[0]
assert hinge.confidence == 0.5
assert len(hinge.missing_mates) > 0
def test_coincident_plane_without_parallel_hints_slider(self) -> None:
"""Coincident(plane) alone matches slider at 0.5 confidence."""
mates = [
_make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
]
results = recognize_patterns(mates)
slider_matches = [r for r in results if r.pattern is JointPattern.SLIDER]
assert len(slider_matches) >= 1
assert slider_matches[0].confidence == 0.5
# ---------------------------------------------------------------------------
# recognize_patterns — ambiguous / multi-body
# ---------------------------------------------------------------------------
class TestRecognizeAmbiguous:
"""Ambiguous patterns and multi-body-pair assemblies."""
def test_concentric_matches_both_hinge_and_cylinder(self) -> None:
"""A single concentric mate produces both hinge (partial) and cylinder matches."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
]
results = recognize_patterns(mates)
patterns = {r.pattern for r in results}
assert JointPattern.HINGE in patterns
assert JointPattern.CYLINDER in patterns
def test_multiple_body_pairs(self) -> None:
"""Mates across different body pairs produce separate pattern matches."""
mates = [
_make_mate(0, MateType.LOCK, 0, 1),
_make_mate(1, MateType.COINCIDENT, 2, 3, GeometryType.POINT, GeometryType.POINT),
]
results = recognize_patterns(mates)
pairs = {(r.body_a, r.body_b) for r in results}
assert (0, 1) in pairs
assert (2, 3) in pairs
def test_results_sorted_by_confidence(self) -> None:
"""All results should be sorted by confidence descending."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
_make_mate(1, MateType.LOCK, 2, 3),
]
results = recognize_patterns(mates)
confidences = [r.confidence for r in results]
assert confidences == sorted(confidences, reverse=True)
def test_unknown_pattern(self) -> None:
"""A mate type that matches no rule returns UNKNOWN."""
mates = [
_make_mate(0, MateType.ANGLE, 0, 1, GeometryType.FACE, GeometryType.FACE),
]
results = recognize_patterns(mates)
assert any(r.pattern is JointPattern.UNKNOWN for r in results)
def test_body_pair_normalization(self) -> None:
"""Mates with reversed body order should be grouped together."""
mates = [
_make_mate(0, MateType.CONCENTRIC, 1, 0, GeometryType.AXIS, GeometryType.AXIS),
_make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
]
results = recognize_patterns(mates)
hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE]
assert len(hinge_matches) >= 1
assert hinge_matches[0].confidence == 1.0

View File

@@ -0,0 +1,329 @@
"""Tests for solver.mates.primitives -- mate type definitions."""
from __future__ import annotations
from typing import ClassVar
import numpy as np
import pytest
from solver.mates.primitives import (
GeometryRef,
GeometryType,
Mate,
MateType,
dof_removed,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_ref(
body_id: int,
geom_type: GeometryType,
*,
geometry_id: str = "Geom001",
origin: np.ndarray | None = None,
direction: np.ndarray | None = None,
) -> GeometryRef:
"""Factory for GeometryRef with sensible defaults."""
if origin is None:
origin = np.zeros(3)
if direction is None and geom_type in {
GeometryType.FACE,
GeometryType.AXIS,
GeometryType.PLANE,
}:
direction = np.array([0.0, 0.0, 1.0])
return GeometryRef(
body_id=body_id,
geometry_type=geom_type,
geometry_id=geometry_id,
origin=origin,
direction=direction,
)
# ---------------------------------------------------------------------------
# MateType
# ---------------------------------------------------------------------------
class TestMateType:
"""MateType enum construction and DOF values."""
EXPECTED_DOF: ClassVar[dict[str, int]] = {
"COINCIDENT": 3,
"CONCENTRIC": 2,
"PARALLEL": 2,
"PERPENDICULAR": 1,
"TANGENT": 1,
"DISTANCE": 1,
"ANGLE": 1,
"LOCK": 6,
}
def test_member_count(self) -> None:
assert len(MateType) == 8
@pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
def test_default_dof_values(self, name: str, dof: int) -> None:
assert MateType[name].default_dof == dof
def test_value_is_tuple(self) -> None:
assert MateType.COINCIDENT.value == (0, 3)
assert MateType.COINCIDENT.default_dof == 3
def test_access_by_name(self) -> None:
assert MateType["LOCK"] is MateType.LOCK
def test_no_alias_collision(self) -> None:
ordinals = [m.value[0] for m in MateType]
assert len(ordinals) == len(set(ordinals))
# ---------------------------------------------------------------------------
# GeometryType
# ---------------------------------------------------------------------------
class TestGeometryType:
"""GeometryType enum."""
def test_member_count(self) -> None:
assert len(GeometryType) == 5
def test_string_values(self) -> None:
for gt in GeometryType:
assert isinstance(gt.value, str)
assert gt.value == gt.name.lower()
def test_access_by_name(self) -> None:
assert GeometryType["FACE"] is GeometryType.FACE
# ---------------------------------------------------------------------------
# GeometryRef
# ---------------------------------------------------------------------------
class TestGeometryRef:
"""GeometryRef dataclass."""
def test_construction(self) -> None:
ref = _make_ref(0, GeometryType.AXIS, geometry_id="Axis001")
assert ref.body_id == 0
assert ref.geometry_type is GeometryType.AXIS
assert ref.geometry_id == "Axis001"
np.testing.assert_array_equal(ref.origin, np.zeros(3))
assert ref.direction is not None
def test_default_direction_none(self) -> None:
ref = GeometryRef(
body_id=0,
geometry_type=GeometryType.POINT,
geometry_id="Point001",
)
assert ref.direction is None
def test_to_dict_round_trip(self) -> None:
ref = _make_ref(
1,
GeometryType.FACE,
origin=np.array([1.0, 2.0, 3.0]),
direction=np.array([0.0, 1.0, 0.0]),
)
d = ref.to_dict()
restored = GeometryRef.from_dict(d)
assert restored.body_id == ref.body_id
assert restored.geometry_type is ref.geometry_type
assert restored.geometry_id == ref.geometry_id
np.testing.assert_array_almost_equal(restored.origin, ref.origin)
assert restored.direction is not None
np.testing.assert_array_almost_equal(restored.direction, ref.direction)
def test_to_dict_with_none_direction(self) -> None:
ref = GeometryRef(
body_id=2,
geometry_type=GeometryType.POINT,
geometry_id="Point002",
origin=np.array([5.0, 6.0, 7.0]),
)
d = ref.to_dict()
assert d["direction"] is None
restored = GeometryRef.from_dict(d)
assert restored.direction is None
# ---------------------------------------------------------------------------
# Mate
# ---------------------------------------------------------------------------
class TestMate:
"""Mate dataclass."""
def test_construction(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.FACE)
m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
assert m.mate_id == 0
assert m.mate_type is MateType.COINCIDENT
def test_value_default_zero(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.FACE)
m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
assert m.value == 0.0
def test_tolerance_default(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.FACE)
m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
assert m.tolerance == 1e-6
def test_to_dict_round_trip(self) -> None:
ref_a = _make_ref(0, GeometryType.AXIS, origin=np.array([1.0, 0.0, 0.0]))
ref_b = _make_ref(1, GeometryType.AXIS, origin=np.array([2.0, 0.0, 0.0]))
m = Mate(
mate_id=5,
mate_type=MateType.CONCENTRIC,
ref_a=ref_a,
ref_b=ref_b,
value=0.0,
tolerance=1e-8,
)
d = m.to_dict()
restored = Mate.from_dict(d)
assert restored.mate_id == m.mate_id
assert restored.mate_type is m.mate_type
assert restored.ref_a.body_id == m.ref_a.body_id
assert restored.ref_b.body_id == m.ref_b.body_id
assert restored.value == m.value
assert restored.tolerance == m.tolerance
def test_from_dict_missing_optional(self) -> None:
d = {
"mate_id": 1,
"mate_type": "DISTANCE",
"ref_a": _make_ref(0, GeometryType.POINT).to_dict(),
"ref_b": _make_ref(1, GeometryType.POINT).to_dict(),
}
m = Mate.from_dict(d)
assert m.value == 0.0
assert m.tolerance == 1e-6
# ---------------------------------------------------------------------------
# dof_removed
# ---------------------------------------------------------------------------
class TestDofRemoved:
"""Context-dependent DOF removal counts."""
def test_coincident_face_face(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.FACE)
assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3
def test_coincident_point_point(self) -> None:
ref_a = _make_ref(0, GeometryType.POINT)
ref_b = _make_ref(1, GeometryType.POINT)
assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3
def test_coincident_edge_edge(self) -> None:
ref_a = _make_ref(0, GeometryType.EDGE)
ref_b = _make_ref(1, GeometryType.EDGE)
assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 2
def test_coincident_face_point(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.POINT)
assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 1
def test_concentric_axis_axis(self) -> None:
ref_a = _make_ref(0, GeometryType.AXIS)
ref_b = _make_ref(1, GeometryType.AXIS)
assert dof_removed(MateType.CONCENTRIC, ref_a, ref_b) == 2
def test_lock_any(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.POINT)
assert dof_removed(MateType.LOCK, ref_a, ref_b) == 6
def test_distance_any(self) -> None:
ref_a = _make_ref(0, GeometryType.POINT)
ref_b = _make_ref(1, GeometryType.EDGE)
assert dof_removed(MateType.DISTANCE, ref_a, ref_b) == 1
def test_unknown_combo_uses_default(self) -> None:
"""Unlisted geometry combos fall back to default_dof."""
ref_a = _make_ref(0, GeometryType.EDGE)
ref_b = _make_ref(1, GeometryType.POINT)
result = dof_removed(MateType.COINCIDENT, ref_a, ref_b)
assert result == MateType.COINCIDENT.default_dof
# ---------------------------------------------------------------------------
# Mate.validate
# ---------------------------------------------------------------------------
class TestMateValidation:
"""Mate.validate() compatibility checks."""
def test_valid_concentric(self) -> None:
ref_a = _make_ref(0, GeometryType.AXIS)
ref_b = _make_ref(1, GeometryType.AXIS)
m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
m.validate() # should not raise
def test_invalid_concentric_face(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.AXIS)
m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
with pytest.raises(ValueError, match="CONCENTRIC"):
m.validate()
def test_valid_coincident_face_face(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(1, GeometryType.FACE)
m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
m.validate() # should not raise
def test_invalid_self_mate(self) -> None:
ref_a = _make_ref(0, GeometryType.FACE)
ref_b = _make_ref(0, GeometryType.FACE, geometry_id="Face002")
m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
with pytest.raises(ValueError, match="Self-mate"):
m.validate()
def test_invalid_parallel_point(self) -> None:
ref_a = _make_ref(0, GeometryType.POINT)
ref_b = _make_ref(1, GeometryType.AXIS)
m = Mate(mate_id=0, mate_type=MateType.PARALLEL, ref_a=ref_a, ref_b=ref_b)
with pytest.raises(ValueError, match="PARALLEL"):
m.validate()
def test_invalid_tangent_axis(self) -> None:
ref_a = _make_ref(0, GeometryType.AXIS)
ref_b = _make_ref(1, GeometryType.FACE)
m = Mate(mate_id=0, mate_type=MateType.TANGENT, ref_a=ref_a, ref_b=ref_b)
with pytest.raises(ValueError, match="TANGENT"):
m.validate()
def test_missing_direction_for_axis(self) -> None:
ref_a = GeometryRef(
body_id=0,
geometry_type=GeometryType.AXIS,
geometry_id="Axis001",
origin=np.zeros(3),
direction=None, # missing!
)
ref_b = _make_ref(1, GeometryType.AXIS)
m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
with pytest.raises(ValueError, match="direction"):
m.validate()