Compare commits
1 Commits
public
...
feat/gnn-m
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fe41fa3b00 |
@@ -2,55 +2,24 @@ name: CI
|
|||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches: [main, public]
|
branches: [main]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [main, public]
|
branches: [main]
|
||||||
workflow_dispatch:
|
|
||||||
inputs:
|
|
||||||
run_datagen:
|
|
||||||
description: "Run dataset generation"
|
|
||||||
required: false
|
|
||||||
type: boolean
|
|
||||||
default: false
|
|
||||||
num_assemblies:
|
|
||||||
description: "Number of assemblies to generate"
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: "100000"
|
|
||||||
num_workers:
|
|
||||||
description: "Parallel workers for datagen"
|
|
||||||
required: false
|
|
||||||
type: string
|
|
||||||
default: "4"
|
|
||||||
|
|
||||||
env:
|
|
||||||
PIP_CACHE_DIR: /tmp/pip-cache-solver
|
|
||||||
TORCH_INDEX: https://download.pytorch.org/whl/cpu
|
|
||||||
VIRTUAL_ENV: /tmp/solver-venv
|
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Lint — fast, no torch required
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
lint:
|
lint:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
|
||||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
pip install ruff mypy
|
||||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
pip install -e ".[dev]" || pip install ruff mypy numpy
|
||||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
|
||||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
|
||||||
cd "$GITHUB_WORKSPACE"
|
|
||||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
|
||||||
|
|
||||||
- name: Set up venv
|
|
||||||
run: python3 -m venv $VIRTUAL_ENV
|
|
||||||
|
|
||||||
- name: Install lint tools
|
|
||||||
run: pip install --cache-dir $PIP_CACHE_DIR ruff
|
|
||||||
|
|
||||||
- name: Ruff check
|
- name: Ruff check
|
||||||
run: ruff check solver/ freecad/ tests/ scripts/
|
run: ruff check solver/ freecad/ tests/ scripts/
|
||||||
@@ -58,123 +27,39 @@ jobs:
|
|||||||
- name: Ruff format check
|
- name: Ruff format check
|
||||||
run: ruff format --check solver/ freecad/ tests/ scripts/
|
run: ruff format --check solver/ freecad/ tests/ scripts/
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Type check
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
type-check:
|
type-check:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
|
||||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- uses: actions/checkout@v4
|
||||||
run: |
|
|
||||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
|
||||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
|
||||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
|
||||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
|
||||||
cd "$GITHUB_WORKSPACE"
|
|
||||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
|
||||||
|
|
||||||
- name: Set up venv
|
- uses: actions/setup-python@v5
|
||||||
run: python3 -m venv $VIRTUAL_ENV
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
pip install mypy numpy
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||||
pip install --cache-dir $PIP_CACHE_DIR mypy numpy scipy
|
pip install torch-geometric
|
||||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[dev]"
|
pip install -e ".[dev]"
|
||||||
|
|
||||||
- name: Mypy
|
- name: Mypy
|
||||||
run: mypy solver/ freecad/
|
run: mypy solver/ freecad/
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Tests
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
test:
|
test:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
env:
|
|
||||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout
|
- uses: actions/checkout@v4
|
||||||
run: |
|
|
||||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
|
||||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
|
||||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
|
||||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
|
||||||
cd "$GITHUB_WORKSPACE"
|
|
||||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
|
||||||
|
|
||||||
- name: Set up venv
|
- uses: actions/setup-python@v5
|
||||||
run: python3 -m venv $VIRTUAL_ENV
|
with:
|
||||||
|
python-version: "3.11"
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: |
|
run: |
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
pip install torch --index-url https://download.pytorch.org/whl/cpu
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
pip install torch-geometric
|
||||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[train,dev]"
|
pip install -e ".[train,dev]"
|
||||||
|
|
||||||
- name: Run tests
|
- name: Run tests
|
||||||
run: pytest tests/ freecad/tests/ -v --tb=short
|
run: pytest tests/ freecad/tests/ -v --tb=short
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Dataset generation — manual trigger or on main/public push
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
datagen:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: >-
|
|
||||||
(github.event_name == 'workflow_dispatch' && inputs.run_datagen == true) ||
|
|
||||||
(github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/public'))
|
|
||||||
needs: [test]
|
|
||||||
env:
|
|
||||||
PATH: /tmp/solver-venv/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
|
|
||||||
steps:
|
|
||||||
- name: Checkout
|
|
||||||
run: |
|
|
||||||
git config --global --add safe.directory "$GITHUB_WORKSPACE"
|
|
||||||
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
|
||||||
"${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE" \
|
|
||||||
|| git clone --depth 1 "${GITHUB_SERVER_URL}/${GITHUB_REPOSITORY}.git" "$GITHUB_WORKSPACE"
|
|
||||||
cd "$GITHUB_WORKSPACE"
|
|
||||||
git checkout "$GITHUB_SHA" 2>/dev/null || true
|
|
||||||
|
|
||||||
- name: Set up venv
|
|
||||||
run: python3 -m venv $VIRTUAL_ENV
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: |
|
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch --index-url $TORCH_INDEX
|
|
||||||
pip install --cache-dir $PIP_CACHE_DIR torch-geometric
|
|
||||||
pip install --cache-dir $PIP_CACHE_DIR -e ".[train]"
|
|
||||||
|
|
||||||
- name: Generate dataset
|
|
||||||
run: |
|
|
||||||
NUM=${INPUTS_NUM_ASSEMBLIES:-100000}
|
|
||||||
WORKERS=${INPUTS_NUM_WORKERS:-4}
|
|
||||||
echo "Generating ${NUM} assemblies with ${WORKERS} workers"
|
|
||||||
python3 scripts/generate_synthetic.py \
|
|
||||||
--num-assemblies "${NUM}" \
|
|
||||||
--num-workers "${WORKERS}" \
|
|
||||||
--output-dir data/synthetic
|
|
||||||
env:
|
|
||||||
INPUTS_NUM_ASSEMBLIES: ${{ inputs.num_assemblies }}
|
|
||||||
INPUTS_NUM_WORKERS: ${{ inputs.num_workers }}
|
|
||||||
|
|
||||||
- name: Print summary
|
|
||||||
if: always()
|
|
||||||
run: |
|
|
||||||
echo "=== Dataset Generation Results ==="
|
|
||||||
if [ -f data/synthetic/stats.json ]; then
|
|
||||||
python3 -c "
|
|
||||||
import json
|
|
||||||
with open('data/synthetic/stats.json') as f:
|
|
||||||
s = json.load(f)
|
|
||||||
print(f'Total examples: {s[\"total_examples\"]}')
|
|
||||||
print(f'Classification: {json.dumps(s[\"classification_distribution\"], indent=2)}')
|
|
||||||
print(f'Rigid: {s[\"rigidity\"][\"rigid_fraction\"]*100:.1f}%')
|
|
||||||
print(f'Degeneracy: {s[\"geometric_degeneracy\"][\"fraction_with_degeneracy\"]*100:.1f}%')
|
|
||||||
"
|
|
||||||
else
|
|
||||||
echo "stats.json not found — generation may have failed"
|
|
||||||
ls -la data/synthetic/ 2>/dev/null || echo "output dir missing"
|
|
||||||
fi
|
|
||||||
|
|||||||
@@ -16,9 +16,9 @@ heads:
|
|||||||
hidden_dim: 64
|
hidden_dim: 64
|
||||||
graph_classification:
|
graph_classification:
|
||||||
enabled: true
|
enabled: true
|
||||||
num_classes: 4 # rigid, under, over, mixed
|
num_classes: 4 # rigid, under, over, mixed
|
||||||
joint_type:
|
joint_type:
|
||||||
enabled: true
|
enabled: true
|
||||||
num_classes: 12
|
num_classes: 11
|
||||||
dof_regression:
|
dof_regression:
|
||||||
enabled: true
|
enabled: true
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ heads:
|
|||||||
num_classes: 4
|
num_classes: 4
|
||||||
joint_type:
|
joint_type:
|
||||||
enabled: true
|
enabled: true
|
||||||
num_classes: 12
|
num_classes: 11
|
||||||
dof_regression:
|
dof_regression:
|
||||||
enabled: true
|
enabled: true
|
||||||
dof_tracking:
|
dof_tracking:
|
||||||
|
|||||||
@@ -877,6 +877,9 @@ class SyntheticAssemblyGenerator:
|
|||||||
"body_b": j.body_b,
|
"body_b": j.body_b,
|
||||||
"type": j.joint_type.name,
|
"type": j.joint_type.name,
|
||||||
"axis": j.axis.tolist(),
|
"axis": j.axis.tolist(),
|
||||||
|
"anchor_a": j.anchor_a.tolist(),
|
||||||
|
"anchor_b": j.anchor_b.tolist(),
|
||||||
|
"pitch": j.pitch,
|
||||||
}
|
}
|
||||||
for j in joints
|
for j in joints
|
||||||
],
|
],
|
||||||
|
|||||||
@@ -1,47 +0,0 @@
|
|||||||
"""Mate-level constraint types for assembly analysis."""
|
|
||||||
|
|
||||||
from solver.mates.conversion import (
|
|
||||||
MateAnalysisResult,
|
|
||||||
analyze_mate_assembly,
|
|
||||||
convert_mates_to_joints,
|
|
||||||
)
|
|
||||||
from solver.mates.generator import (
|
|
||||||
SyntheticMateGenerator,
|
|
||||||
generate_mate_training_batch,
|
|
||||||
)
|
|
||||||
from solver.mates.labeling import (
|
|
||||||
MateAssemblyLabels,
|
|
||||||
MateLabel,
|
|
||||||
label_mate_assembly,
|
|
||||||
)
|
|
||||||
from solver.mates.patterns import (
|
|
||||||
JointPattern,
|
|
||||||
PatternMatch,
|
|
||||||
recognize_patterns,
|
|
||||||
)
|
|
||||||
from solver.mates.primitives import (
|
|
||||||
GeometryRef,
|
|
||||||
GeometryType,
|
|
||||||
Mate,
|
|
||||||
MateType,
|
|
||||||
dof_removed,
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"GeometryRef",
|
|
||||||
"GeometryType",
|
|
||||||
"JointPattern",
|
|
||||||
"Mate",
|
|
||||||
"MateAnalysisResult",
|
|
||||||
"MateAssemblyLabels",
|
|
||||||
"MateLabel",
|
|
||||||
"MateType",
|
|
||||||
"PatternMatch",
|
|
||||||
"SyntheticMateGenerator",
|
|
||||||
"analyze_mate_assembly",
|
|
||||||
"convert_mates_to_joints",
|
|
||||||
"dof_removed",
|
|
||||||
"generate_mate_training_batch",
|
|
||||||
"label_mate_assembly",
|
|
||||||
"recognize_patterns",
|
|
||||||
]
|
|
||||||
@@ -1,276 +0,0 @@
|
|||||||
"""Mate-to-joint conversion and assembly analysis.
|
|
||||||
|
|
||||||
Bridges the mate-level constraint representation to the existing
|
|
||||||
joint-based analysis pipeline. Converts recognized mate patterns
|
|
||||||
to Joint objects, then runs the pebble game and Jacobian analysis,
|
|
||||||
maintaining bidirectional traceability between mates and joints.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from solver.datagen.labeling import AssemblyLabels, label_assembly
|
|
||||||
from solver.datagen.types import (
|
|
||||||
ConstraintAnalysis,
|
|
||||||
Joint,
|
|
||||||
JointType,
|
|
||||||
RigidBody,
|
|
||||||
)
|
|
||||||
from solver.mates.patterns import PatternMatch, recognize_patterns
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from solver.mates.primitives import Mate
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MateAnalysisResult",
|
|
||||||
"analyze_mate_assembly",
|
|
||||||
"convert_mates_to_joints",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Result dataclass
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MateAnalysisResult:
|
|
||||||
"""Combined result of mate-based assembly analysis.
|
|
||||||
|
|
||||||
Attributes:
|
|
||||||
patterns: Recognized joint patterns from mate grouping.
|
|
||||||
joints: Joint objects produced by conversion.
|
|
||||||
mate_to_joint: Mapping from mate_id to list of joint_ids.
|
|
||||||
joint_to_mates: Mapping from joint_id to list of mate_ids.
|
|
||||||
analysis: Constraint analysis from pebble game + Jacobian.
|
|
||||||
labels: Full ground truth labels from label_assembly.
|
|
||||||
"""
|
|
||||||
|
|
||||||
patterns: list[PatternMatch]
|
|
||||||
joints: list[Joint]
|
|
||||||
mate_to_joint: dict[int, list[int]] = field(default_factory=dict)
|
|
||||||
joint_to_mates: dict[int, list[int]] = field(default_factory=dict)
|
|
||||||
analysis: ConstraintAnalysis | None = None
|
|
||||||
labels: AssemblyLabels | None = None
|
|
||||||
|
|
||||||
def to_dict(self) -> dict[str, Any]:
|
|
||||||
"""Return a JSON-serializable dict."""
|
|
||||||
return {
|
|
||||||
"patterns": [p.to_dict() for p in self.patterns],
|
|
||||||
"joints": [
|
|
||||||
{
|
|
||||||
"joint_id": j.joint_id,
|
|
||||||
"body_a": j.body_a,
|
|
||||||
"body_b": j.body_b,
|
|
||||||
"joint_type": j.joint_type.name,
|
|
||||||
}
|
|
||||||
for j in self.joints
|
|
||||||
],
|
|
||||||
"mate_to_joint": self.mate_to_joint,
|
|
||||||
"joint_to_mates": self.joint_to_mates,
|
|
||||||
"labels": self.labels.to_dict() if self.labels else None,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pattern-to-JointType mapping
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Maps (JointPattern value) to JointType for known patterns.
|
|
||||||
# Used by convert_mates_to_joints when a full pattern is recognized.
|
|
||||||
_PATTERN_JOINT_MAP: dict[str, JointType] = {
|
|
||||||
"hinge": JointType.REVOLUTE,
|
|
||||||
"slider": JointType.SLIDER,
|
|
||||||
"cylinder": JointType.CYLINDRICAL,
|
|
||||||
"ball": JointType.BALL,
|
|
||||||
"planar": JointType.PLANAR,
|
|
||||||
"fixed": JointType.FIXED,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Fallback mapping for individual mate types when no pattern is recognized.
|
|
||||||
_MATE_JOINT_FALLBACK: dict[str, JointType] = {
|
|
||||||
"COINCIDENT": JointType.PLANAR,
|
|
||||||
"CONCENTRIC": JointType.CYLINDRICAL,
|
|
||||||
"PARALLEL": JointType.PARALLEL,
|
|
||||||
"PERPENDICULAR": JointType.PERPENDICULAR,
|
|
||||||
"TANGENT": JointType.DISTANCE,
|
|
||||||
"DISTANCE": JointType.DISTANCE,
|
|
||||||
"ANGLE": JointType.PERPENDICULAR,
|
|
||||||
"LOCK": JointType.FIXED,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Conversion
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _compute_joint_params(
|
|
||||||
pattern: PatternMatch,
|
|
||||||
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
|
|
||||||
"""Extract anchor and axis from pattern mates.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(anchor_a, anchor_b, axis)
|
|
||||||
"""
|
|
||||||
anchor_a = np.zeros(3)
|
|
||||||
anchor_b = np.zeros(3)
|
|
||||||
axis = np.array([0.0, 0.0, 1.0])
|
|
||||||
|
|
||||||
for mate in pattern.mates:
|
|
||||||
ref_a = mate.ref_a
|
|
||||||
ref_b = mate.ref_b
|
|
||||||
anchor_a = ref_a.origin.copy()
|
|
||||||
anchor_b = ref_b.origin.copy()
|
|
||||||
if ref_a.direction is not None:
|
|
||||||
axis = ref_a.direction.copy()
|
|
||||||
break
|
|
||||||
|
|
||||||
return anchor_a, anchor_b, axis
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_single_mate(
|
|
||||||
mate: Mate,
|
|
||||||
joint_id: int,
|
|
||||||
) -> Joint:
|
|
||||||
"""Convert a single unmatched mate to a Joint."""
|
|
||||||
joint_type = _MATE_JOINT_FALLBACK.get(mate.mate_type.name, JointType.DISTANCE)
|
|
||||||
|
|
||||||
anchor_a = mate.ref_a.origin.copy()
|
|
||||||
anchor_b = mate.ref_b.origin.copy()
|
|
||||||
axis = np.array([0.0, 0.0, 1.0])
|
|
||||||
if mate.ref_a.direction is not None:
|
|
||||||
axis = mate.ref_a.direction.copy()
|
|
||||||
|
|
||||||
return Joint(
|
|
||||||
joint_id=joint_id,
|
|
||||||
body_a=mate.ref_a.body_id,
|
|
||||||
body_b=mate.ref_b.body_id,
|
|
||||||
joint_type=joint_type,
|
|
||||||
anchor_a=anchor_a,
|
|
||||||
anchor_b=anchor_b,
|
|
||||||
axis=axis,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def convert_mates_to_joints(
|
|
||||||
mates: list[Mate],
|
|
||||||
bodies: list[RigidBody] | None = None,
|
|
||||||
) -> tuple[list[Joint], dict[int, list[int]], dict[int, list[int]]]:
|
|
||||||
"""Convert mates to Joint objects via pattern recognition.
|
|
||||||
|
|
||||||
For each body pair:
|
|
||||||
- If mates form a recognized pattern, emit the equivalent joint.
|
|
||||||
- Otherwise, emit individual joints for each unmatched mate.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
mates: Mate constraints to convert.
|
|
||||||
bodies: Optional body list (unused currently, reserved for
|
|
||||||
future geometry lookups).
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(joints, mate_to_joint, joint_to_mates) tuple.
|
|
||||||
"""
|
|
||||||
if not mates:
|
|
||||||
return [], {}, {}
|
|
||||||
|
|
||||||
patterns = recognize_patterns(mates)
|
|
||||||
joints: list[Joint] = []
|
|
||||||
mate_to_joint: dict[int, list[int]] = {}
|
|
||||||
joint_to_mates: dict[int, list[int]] = {}
|
|
||||||
|
|
||||||
# Track which mates have been consumed by full-confidence patterns
|
|
||||||
consumed_mate_ids: set[int] = set()
|
|
||||||
next_joint_id = 0
|
|
||||||
|
|
||||||
# First pass: emit joints for full-confidence patterns
|
|
||||||
for pattern in patterns:
|
|
||||||
if pattern.confidence < 1.0:
|
|
||||||
continue
|
|
||||||
if pattern.pattern.value not in _PATTERN_JOINT_MAP:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Check if any of these mates were already consumed
|
|
||||||
mate_ids = [m.mate_id for m in pattern.mates]
|
|
||||||
if any(mid in consumed_mate_ids for mid in mate_ids):
|
|
||||||
continue
|
|
||||||
|
|
||||||
joint_type = _PATTERN_JOINT_MAP[pattern.pattern.value]
|
|
||||||
anchor_a, anchor_b, axis = _compute_joint_params(pattern)
|
|
||||||
|
|
||||||
joint = Joint(
|
|
||||||
joint_id=next_joint_id,
|
|
||||||
body_a=pattern.body_a,
|
|
||||||
body_b=pattern.body_b,
|
|
||||||
joint_type=joint_type,
|
|
||||||
anchor_a=anchor_a,
|
|
||||||
anchor_b=anchor_b,
|
|
||||||
axis=axis,
|
|
||||||
)
|
|
||||||
joints.append(joint)
|
|
||||||
|
|
||||||
joint_to_mates[next_joint_id] = mate_ids
|
|
||||||
for mid in mate_ids:
|
|
||||||
mate_to_joint.setdefault(mid, []).append(next_joint_id)
|
|
||||||
consumed_mate_ids.add(mid)
|
|
||||||
|
|
||||||
next_joint_id += 1
|
|
||||||
|
|
||||||
# Second pass: emit individual joints for unconsumed mates
|
|
||||||
for mate in mates:
|
|
||||||
if mate.mate_id in consumed_mate_ids:
|
|
||||||
continue
|
|
||||||
|
|
||||||
joint = _convert_single_mate(mate, next_joint_id)
|
|
||||||
joints.append(joint)
|
|
||||||
|
|
||||||
joint_to_mates[next_joint_id] = [mate.mate_id]
|
|
||||||
mate_to_joint.setdefault(mate.mate_id, []).append(next_joint_id)
|
|
||||||
|
|
||||||
next_joint_id += 1
|
|
||||||
|
|
||||||
return joints, mate_to_joint, joint_to_mates
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Full analysis pipeline
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def analyze_mate_assembly(
|
|
||||||
bodies: list[RigidBody],
|
|
||||||
mates: list[Mate],
|
|
||||||
ground_body: int | None = None,
|
|
||||||
) -> MateAnalysisResult:
|
|
||||||
"""Run the full analysis pipeline on a mate-based assembly.
|
|
||||||
|
|
||||||
Orchestrates: recognize_patterns -> convert_mates_to_joints ->
|
|
||||||
label_assembly, returning a combined result with full traceability.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
bodies: Rigid bodies in the assembly.
|
|
||||||
mates: Mate constraints between the bodies.
|
|
||||||
ground_body: If set, this body is fixed to the world.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
MateAnalysisResult with patterns, joints, mappings, and labels.
|
|
||||||
"""
|
|
||||||
patterns = recognize_patterns(mates)
|
|
||||||
joints, mate_to_joint, joint_to_mates = convert_mates_to_joints(mates, bodies)
|
|
||||||
|
|
||||||
labels = label_assembly(bodies, joints, ground_body)
|
|
||||||
|
|
||||||
return MateAnalysisResult(
|
|
||||||
patterns=patterns,
|
|
||||||
joints=joints,
|
|
||||||
mate_to_joint=mate_to_joint,
|
|
||||||
joint_to_mates=joint_to_mates,
|
|
||||||
analysis=labels.analysis,
|
|
||||||
labels=labels,
|
|
||||||
)
|
|
||||||
@@ -1,315 +0,0 @@
|
|||||||
"""Mate-based synthetic assembly generator.
|
|
||||||
|
|
||||||
Wraps SyntheticAssemblyGenerator to produce mate-level training data.
|
|
||||||
Generates joint-based assemblies via the existing generator, then
|
|
||||||
reverse-maps joints to plausible mate combinations. Supports noise
|
|
||||||
injection (redundant, missing, incompatible mates) for robust training.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from solver.datagen.generator import SyntheticAssemblyGenerator
|
|
||||||
from solver.datagen.types import Joint, JointType, RigidBody
|
|
||||||
from solver.mates.conversion import MateAnalysisResult, analyze_mate_assembly
|
|
||||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"SyntheticMateGenerator",
|
|
||||||
"generate_mate_training_batch",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Reverse mapping: JointType -> list of (MateType, geom_a, geom_b) combos
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class _MateSpec:
|
|
||||||
"""Specification for a mate to generate from a joint."""
|
|
||||||
|
|
||||||
mate_type: MateType
|
|
||||||
geom_a: GeometryType
|
|
||||||
geom_b: GeometryType
|
|
||||||
|
|
||||||
|
|
||||||
_JOINT_TO_MATES: dict[JointType, list[_MateSpec]] = {
|
|
||||||
JointType.REVOLUTE: [
|
|
||||||
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
|
|
||||||
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
|
|
||||||
],
|
|
||||||
JointType.CYLINDRICAL: [
|
|
||||||
_MateSpec(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
|
|
||||||
],
|
|
||||||
JointType.BALL: [
|
|
||||||
_MateSpec(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),
|
|
||||||
],
|
|
||||||
JointType.FIXED: [
|
|
||||||
_MateSpec(MateType.LOCK, GeometryType.FACE, GeometryType.FACE),
|
|
||||||
],
|
|
||||||
JointType.SLIDER: [
|
|
||||||
_MateSpec(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
|
|
||||||
_MateSpec(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
|
|
||||||
],
|
|
||||||
JointType.PLANAR: [
|
|
||||||
_MateSpec(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Generator
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class SyntheticMateGenerator:
|
|
||||||
"""Generates mate-based assemblies for training data.
|
|
||||||
|
|
||||||
Wraps SyntheticAssemblyGenerator to produce joint-based assemblies,
|
|
||||||
then reverse-maps each joint to a plausible set of mate constraints.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
seed: Random seed for reproducibility.
|
|
||||||
redundant_prob: Probability of injecting a redundant mate per joint.
|
|
||||||
missing_prob: Probability of dropping a mate from a multi-mate pattern.
|
|
||||||
incompatible_prob: Probability of injecting a mate with wrong geometry.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
seed: int = 42,
|
|
||||||
*,
|
|
||||||
redundant_prob: float = 0.0,
|
|
||||||
missing_prob: float = 0.0,
|
|
||||||
incompatible_prob: float = 0.0,
|
|
||||||
) -> None:
|
|
||||||
self._joint_gen = SyntheticAssemblyGenerator(seed=seed)
|
|
||||||
self._rng = np.random.default_rng(seed)
|
|
||||||
self.redundant_prob = redundant_prob
|
|
||||||
self.missing_prob = missing_prob
|
|
||||||
self.incompatible_prob = incompatible_prob
|
|
||||||
|
|
||||||
def _make_geometry_ref(
|
|
||||||
self,
|
|
||||||
body_id: int,
|
|
||||||
geom_type: GeometryType,
|
|
||||||
joint: Joint,
|
|
||||||
*,
|
|
||||||
is_ref_a: bool = True,
|
|
||||||
) -> GeometryRef:
|
|
||||||
"""Create a GeometryRef from joint geometry.
|
|
||||||
|
|
||||||
Uses joint anchor, axis, and body_id to produce a ref
|
|
||||||
with realistic geometry for the given type.
|
|
||||||
"""
|
|
||||||
origin = joint.anchor_a if is_ref_a else joint.anchor_b
|
|
||||||
|
|
||||||
direction: np.ndarray | None = None
|
|
||||||
if geom_type in {GeometryType.AXIS, GeometryType.PLANE, GeometryType.FACE}:
|
|
||||||
direction = joint.axis.copy()
|
|
||||||
|
|
||||||
geom_id = f"{geom_type.value.capitalize()}001"
|
|
||||||
|
|
||||||
return GeometryRef(
|
|
||||||
body_id=body_id,
|
|
||||||
geometry_type=geom_type,
|
|
||||||
geometry_id=geom_id,
|
|
||||||
origin=origin.copy(),
|
|
||||||
direction=direction,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _reverse_map_joint(
|
|
||||||
self,
|
|
||||||
joint: Joint,
|
|
||||||
next_mate_id: int,
|
|
||||||
) -> list[Mate]:
|
|
||||||
"""Convert a joint to its mate representation."""
|
|
||||||
specs = _JOINT_TO_MATES.get(joint.joint_type, [])
|
|
||||||
if not specs:
|
|
||||||
# Fallback: emit a single DISTANCE mate
|
|
||||||
specs = [_MateSpec(MateType.DISTANCE, GeometryType.POINT, GeometryType.POINT)]
|
|
||||||
|
|
||||||
mates: list[Mate] = []
|
|
||||||
for spec in specs:
|
|
||||||
ref_a = self._make_geometry_ref(joint.body_a, spec.geom_a, joint, is_ref_a=True)
|
|
||||||
ref_b = self._make_geometry_ref(joint.body_b, spec.geom_b, joint, is_ref_a=False)
|
|
||||||
mates.append(
|
|
||||||
Mate(
|
|
||||||
mate_id=next_mate_id + len(mates),
|
|
||||||
mate_type=spec.mate_type,
|
|
||||||
ref_a=ref_a,
|
|
||||||
ref_b=ref_b,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return mates
|
|
||||||
|
|
||||||
def _inject_noise(
|
|
||||||
self,
|
|
||||||
mates: list[Mate],
|
|
||||||
next_mate_id: int,
|
|
||||||
) -> list[Mate]:
|
|
||||||
"""Apply noise injection to the mate list.
|
|
||||||
|
|
||||||
Modifies the list in-place and may add new mates.
|
|
||||||
Returns the (possibly extended) list.
|
|
||||||
"""
|
|
||||||
result = list(mates)
|
|
||||||
extra: list[Mate] = []
|
|
||||||
|
|
||||||
for mate in mates:
|
|
||||||
# Redundant: duplicate a mate
|
|
||||||
if self._rng.random() < self.redundant_prob:
|
|
||||||
dup = Mate(
|
|
||||||
mate_id=next_mate_id + len(extra),
|
|
||||||
mate_type=mate.mate_type,
|
|
||||||
ref_a=mate.ref_a,
|
|
||||||
ref_b=mate.ref_b,
|
|
||||||
value=mate.value,
|
|
||||||
tolerance=mate.tolerance,
|
|
||||||
)
|
|
||||||
extra.append(dup)
|
|
||||||
|
|
||||||
# Incompatible: wrong geometry type
|
|
||||||
if self._rng.random() < self.incompatible_prob:
|
|
||||||
bad_geom = GeometryType.POINT
|
|
||||||
bad_ref = GeometryRef(
|
|
||||||
body_id=mate.ref_a.body_id,
|
|
||||||
geometry_type=bad_geom,
|
|
||||||
geometry_id="BadGeom001",
|
|
||||||
origin=mate.ref_a.origin.copy(),
|
|
||||||
direction=None,
|
|
||||||
)
|
|
||||||
extra.append(
|
|
||||||
Mate(
|
|
||||||
mate_id=next_mate_id + len(extra),
|
|
||||||
mate_type=MateType.CONCENTRIC,
|
|
||||||
ref_a=bad_ref,
|
|
||||||
ref_b=mate.ref_b,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
result.extend(extra)
|
|
||||||
|
|
||||||
# Missing: drop mates from multi-mate patterns (only if > 1 mate
|
|
||||||
# for same body pair)
|
|
||||||
if self.missing_prob > 0:
|
|
||||||
filtered: list[Mate] = []
|
|
||||||
for mate in result:
|
|
||||||
if self._rng.random() < self.missing_prob:
|
|
||||||
continue
|
|
||||||
filtered.append(mate)
|
|
||||||
# Ensure at least one mate remains
|
|
||||||
if not filtered and result:
|
|
||||||
filtered = [result[0]]
|
|
||||||
result = filtered
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
n_bodies: int = 4,
|
|
||||||
*,
|
|
||||||
grounded: bool = False,
|
|
||||||
) -> tuple[list[RigidBody], list[Mate], MateAnalysisResult]:
|
|
||||||
"""Generate a mate-based assembly.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
n_bodies: Number of rigid bodies.
|
|
||||||
grounded: Whether to ground the first body.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
(bodies, mates, analysis_result) tuple.
|
|
||||||
"""
|
|
||||||
bodies, joints, _analysis = self._joint_gen.generate_chain_assembly(
|
|
||||||
n_bodies,
|
|
||||||
joint_type=JointType.REVOLUTE,
|
|
||||||
grounded=grounded,
|
|
||||||
)
|
|
||||||
|
|
||||||
mates: list[Mate] = []
|
|
||||||
next_id = 0
|
|
||||||
for joint in joints:
|
|
||||||
joint_mates = self._reverse_map_joint(joint, next_id)
|
|
||||||
mates.extend(joint_mates)
|
|
||||||
next_id += len(joint_mates)
|
|
||||||
|
|
||||||
# Apply noise
|
|
||||||
mates = self._inject_noise(mates, next_id)
|
|
||||||
|
|
||||||
ground_body = bodies[0].body_id if grounded else None
|
|
||||||
result = analyze_mate_assembly(bodies, mates, ground_body)
|
|
||||||
|
|
||||||
return bodies, mates, result
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Batch generation
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def generate_mate_training_batch(
    batch_size: int = 100,
    n_bodies_range: tuple[int, int] = (3, 8),
    seed: int = 42,
    *,
    redundant_prob: float = 0.0,
    missing_prob: float = 0.0,
    incompatible_prob: float = 0.0,
    grounded_ratio: float = 1.0,
) -> list[dict[str, Any]]:
    """Produce a batch of mate-level training examples.

    Args:
        batch_size: Number of assemblies to generate.
        n_bodies_range: (min, max_exclusive) body count.
        seed: Random seed.
        redundant_prob: Probability of redundant mate injection.
        missing_prob: Probability of missing mate injection.
        incompatible_prob: Probability of incompatible mate injection.
        grounded_ratio: Fraction of assemblies that are grounded.

    Returns:
        List of dicts with bodies, mates, patterns, and labels.
    """
    rng = np.random.default_rng(seed)

    batch: list[dict[str, Any]] = []
    for idx in range(batch_size):
        generator = SyntheticMateGenerator(
            seed=seed + idx,
            redundant_prob=redundant_prob,
            missing_prob=missing_prob,
            incompatible_prob=incompatible_prob,
        )
        # Draw the size before the grounding coin flip: RNG consumption
        # order is part of the reproducibility contract.
        body_count = int(rng.integers(*n_bodies_range))
        is_grounded = bool(rng.random() < grounded_ratio)

        bodies, mates, result = generator.generate(body_count, grounded=is_grounded)

        example: dict[str, Any] = {
            "bodies": [
                {"body_id": b.body_id, "position": b.position.tolist()} for b in bodies
            ],
            "mates": [m.to_dict() for m in mates],
            "patterns": [p.to_dict() for p in result.patterns],
            "labels": result.labels.to_dict() if result.labels else None,
            "n_bodies": len(bodies),
            "n_mates": len(mates),
            "n_joints": len(result.joints),
        }
        batch.append(example)

    return batch
|
|
||||||
@@ -1,224 +0,0 @@
|
|||||||
"""Mate-level ground truth labels for assembly analysis.
|
|
||||||
|
|
||||||
Back-attributes joint-level independence results to originating mates
|
|
||||||
via the mate-to-joint mapping from conversion.py. Produces per-mate
|
|
||||||
labels indicating whether each mate is independent, redundant, or
|
|
||||||
degenerate.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from solver.mates.conversion import analyze_mate_assembly
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from solver.datagen.labeling import AssemblyLabel
|
|
||||||
from solver.datagen.types import ConstraintAnalysis, RigidBody
|
|
||||||
from solver.mates.conversion import MateAnalysisResult
|
|
||||||
from solver.mates.patterns import JointPattern, PatternMatch
|
|
||||||
from solver.mates.primitives import Mate
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"MateAssemblyLabels",
|
|
||||||
"MateLabel",
|
|
||||||
"label_mate_assembly",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Label dataclasses
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class MateLabel:
    """Ground-truth classification for a single mate.

    Attributes:
        mate_id: The mate this label refers to.
        is_independent: Contributes non-redundant DOF removal.
        is_redundant: Fully redundant (removable without DOF change).
        is_degenerate: Combinatorially independent but geometrically dependent.
        pattern: Which joint pattern this mate belongs to, if any.
        issue: Detected issue type, if any.
    """

    mate_id: int
    is_independent: bool = True
    is_redundant: bool = False
    is_degenerate: bool = False
    pattern: JointPattern | None = None
    issue: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        serialized: dict[str, Any] = {
            "mate_id": self.mate_id,
            "is_independent": self.is_independent,
            "is_redundant": self.is_redundant,
            "is_degenerate": self.is_degenerate,
        }
        # Enum members are always truthy, so a plain None check is equivalent.
        serialized["pattern"] = None if self.pattern is None else self.pattern.value
        serialized["issue"] = self.issue
        return serialized
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class MateAssemblyLabels:
    """Complete mate-level ground truth labels for an assembly.

    Attributes:
        per_mate: Per-mate labels.
        patterns: Recognized joint patterns.
        assembly: Assembly-wide summary label.
        analysis: Constraint analysis from pebble game + Jacobian.
        mate_analysis: Full mate-level analysis result, when available.
    """

    per_mate: list[MateLabel]
    patterns: list[PatternMatch]
    assembly: AssemblyLabel
    analysis: ConstraintAnalysis
    mate_analysis: MateAnalysisResult | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        summary = self.assembly
        return {
            "per_mate": [label.to_dict() for label in self.per_mate],
            "patterns": [match.to_dict() for match in self.patterns],
            "assembly": {
                "classification": summary.classification,
                "total_dof": summary.total_dof,
                "redundant_count": summary.redundant_count,
                "is_rigid": summary.is_rigid,
                "is_minimally_rigid": summary.is_minimally_rigid,
                "has_degeneracy": summary.has_degeneracy,
            },
        }
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Labeling logic
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _build_mate_pattern_map(
    patterns: list[PatternMatch],
) -> dict[int, JointPattern]:
    """Map each mate_id to the pattern it belongs to (best match wins).

    Only full-confidence matches (confidence >= 1.0) contribute; ties
    are resolved in favor of the higher-confidence pattern seen first.
    """
    mapping: dict[int, JointPattern] = {}
    # Walk highest-confidence patterns first so they claim mates first.
    for match in sorted(patterns, key=lambda p: -p.confidence):
        if match.confidence < 1.0:
            continue
        for member in match.mates:
            mapping.setdefault(member.mate_id, match.pattern)
    return mapping
|
|
||||||
|
|
||||||
|
|
||||||
def label_mate_assembly(
    bodies: list[RigidBody],
    mates: list[Mate],
    ground_body: int | None = None,
) -> MateAssemblyLabels:
    """Produce mate-level ground truth labels for an assembly.

    Runs analyze_mate_assembly() internally, then back-attributes
    joint-level independence to originating mates via the mate_to_joint
    mapping.

    A mate is:
    - **redundant** if ALL joints it contributes to are fully redundant
    - **degenerate** if any joint it contributes to is geometrically
      dependent but combinatorially independent
    - **independent** otherwise

    Args:
        bodies: Rigid bodies in the assembly.
        mates: Mate constraints between the bodies.
        ground_body: If set, this body is fixed to the world.

    Returns:
        MateAssemblyLabels with per-mate labels and assembly summary.

    Raises:
        ValueError: If the analysis produced no joint-level labels.
    """
    mate_result = analyze_mate_assembly(bodies, mates, ground_body)

    # Fail fast with a real exception. The previous implementation only
    # asserted `labels is not None` at the very end, which is stripped
    # under `python -O` and would then crash with AttributeError instead.
    if mate_result.labels is None:
        msg = "analyze_mate_assembly produced no labels; cannot build mate labels"
        raise ValueError(msg)

    # Build per-joint redundancy/degeneracy from joint-level labels.
    joint_redundant: dict[int, bool] = {}
    joint_degenerate: dict[int, bool] = {}
    for jl in mate_result.labels.per_joint:
        # A joint is fully redundant if all its constraints are redundant.
        joint_redundant[jl.joint_id] = jl.redundant_count == jl.total and jl.total > 0
        joint_degenerate[jl.joint_id] = False

    # Geometric degeneracy: combinatorially (pebble) independent but
    # geometrically (Jacobian) dependent.
    for cl in mate_result.labels.per_constraint:
        if cl.pebble_independent and not cl.jacobian_independent:
            joint_degenerate[cl.joint_id] = True

    # Pattern membership map (best full-confidence match per mate).
    pattern_map = _build_mate_pattern_map(mate_result.patterns)

    # Back-attribute joint-level results to the originating mates.
    per_mate: list[MateLabel] = []
    for mate in mates:
        mate_joint_ids = mate_result.mate_to_joint.get(mate.mate_id, [])

        if not mate_joint_ids:
            # Mate wasn't converted to any joint (shouldn't happen, but safe).
            per_mate.append(
                MateLabel(
                    mate_id=mate.mate_id,
                    is_independent=False,
                    is_redundant=True,
                    issue="unmapped",
                )
            )
            continue

        # Redundant if ALL contributed joints are redundant.
        all_redundant = all(joint_redundant.get(jid, False) for jid in mate_joint_ids)
        # Degenerate if ANY contributed joint is degenerate.
        any_degenerate = any(joint_degenerate.get(jid, False) for jid in mate_joint_ids)

        # Redundancy takes precedence over degeneracy in the issue string.
        issue: str | None = None
        if all_redundant:
            issue = "redundant"
        elif any_degenerate:
            issue = "degenerate"

        per_mate.append(
            MateLabel(
                mate_id=mate.mate_id,
                is_independent=not all_redundant,
                is_redundant=all_redundant,
                is_degenerate=any_degenerate,
                pattern=pattern_map.get(mate.mate_id),
                issue=issue,
            )
        )

    return MateAssemblyLabels(
        per_mate=per_mate,
        patterns=mate_result.patterns,
        assembly=mate_result.labels.assembly,
        analysis=mate_result.labels.analysis,
        mate_analysis=mate_result,
    )
|
|
||||||
@@ -1,284 +0,0 @@
|
|||||||
"""Joint pattern recognition from mate combinations.
|
|
||||||
|
|
||||||
Groups mates by body pair and matches them against canonical joint
|
|
||||||
patterns (hinge, slider, ball, etc.). Each pattern is a known
|
|
||||||
combination of mate types that together constrain motion equivalently
|
|
||||||
to a single mechanical joint.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import enum
|
|
||||||
from collections import defaultdict
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from solver.datagen.types import JointType
|
|
||||||
from solver.mates.primitives import GeometryType, Mate, MateType
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"JointPattern",
|
|
||||||
"PatternMatch",
|
|
||||||
"recognize_patterns",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Enums
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class JointPattern(enum.Enum):
    """Canonical joint patterns formed by mate combinations.

    Each member names a mechanical joint that a recognized group of
    mates between one body pair is equivalent to. Recognition is
    driven by ``_PATTERN_RULES`` in this module; the per-member
    comments below summarize the corresponding rule where one exists.
    """

    HINGE = "hinge"  # concentric axes + coincident plane
    SLIDER = "slider"  # coincident plane + parallel axis
    CYLINDER = "cylinder"  # concentric axes only
    BALL = "ball"  # coincident points
    PLANAR = "planar"  # coincident faces, or coincident planes
    FIXED = "fixed"  # lock mate
    GEAR = "gear"  # no rule in _PATTERN_RULES here — presumably matched elsewhere; TODO confirm
    RACK_PINION = "rack_pinion"  # no rule in _PATTERN_RULES here — TODO confirm
    UNKNOWN = "unknown"  # fallback when no rule matches a body-pair group
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pattern match result
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class PatternMatch:
    """Outcome of matching one body-pair's mates against a joint pattern.

    Attributes:
        pattern: The identified joint pattern.
        mates: The mates that form this pattern.
        body_a: First body in the pair.
        body_b: Second body in the pair.
        confidence: How well the mates match the canonical pattern (0-1).
        equivalent_joint_type: The JointType this pattern maps to.
        missing_mates: Descriptions of mates absent for a full match.
    """

    pattern: JointPattern
    mates: list[Mate]
    body_a: int
    body_b: int
    confidence: float
    equivalent_joint_type: JointType
    missing_mates: list[str] = field(default_factory=list)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        payload: dict[str, Any] = {
            "pattern": self.pattern.value,
            "body_a": self.body_a,
            "body_b": self.body_b,
            "confidence": self.confidence,
        }
        payload["equivalent_joint_type"] = self.equivalent_joint_type.name
        payload["mate_ids"] = [member.mate_id for member in self.mates]
        payload["missing_mates"] = self.missing_mates
        return payload
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Pattern rules (data-driven)
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class _MateRequirement:
    """A single mate requirement within a pattern rule.

    A mate satisfies the requirement when its ``mate_type`` matches and,
    where specified, the geometry types of its two references match
    ``geometry_a`` / ``geometry_b``. ``None`` acts as a wildcard for
    that position (see ``_mate_matches_requirement``).
    """

    # Required mate constraint type (compared by identity).
    mate_type: MateType
    # Required geometry type for ref_a; None matches any geometry.
    geometry_a: GeometryType | None = None
    # Required geometry type for ref_b; None matches any geometry.
    geometry_b: GeometryType | None = None
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class _PatternRule:
    """Defines a canonical pattern as a set of required mates.

    Consumed by ``_try_match_rule``: the match confidence is the
    fraction of ``required`` entries found among a body pair's mates.
    """

    # Pattern this rule recognizes.
    pattern: JointPattern
    # Mechanical joint type the full pattern is equivalent to.
    joint_type: JointType
    # Mate requirements that together constitute the pattern.
    required: tuple[_MateRequirement, ...]
    # Human-readable summary of the mate combination.
    description: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
# Canonical pattern rules checked against every body-pair mate group by
# recognize_patterns(). Each rule's `description` documents the mate
# combination; rules are tried in list order, and results are later
# sorted by confidence, so ordering here only affects tie-breaking.
_PATTERN_RULES: list[_PatternRule] = [
    _PatternRule(
        pattern=JointPattern.HINGE,
        joint_type=JointType.REVOLUTE,
        required=(
            _MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),
            _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
        ),
        description="Concentric axes + coincident plane",
    ),
    _PatternRule(
        pattern=JointPattern.SLIDER,
        joint_type=JointType.SLIDER,
        required=(
            _MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),
            _MateRequirement(MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS),
        ),
        description="Coincident plane + parallel axis",
    ),
    # Single-requirement rules: a lone mate already implies the joint.
    _PatternRule(
        pattern=JointPattern.CYLINDER,
        joint_type=JointType.CYLINDRICAL,
        required=(_MateRequirement(MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS),),
        description="Concentric axes only",
    ),
    _PatternRule(
        pattern=JointPattern.BALL,
        joint_type=JointType.BALL,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT),),
        description="Coincident points",
    ),
    # PLANAR has two alternative formulations (faces or planes).
    _PatternRule(
        pattern=JointPattern.PLANAR,
        joint_type=JointType.PLANAR,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE),),
        description="Coincident faces",
    ),
    _PatternRule(
        pattern=JointPattern.PLANAR,
        joint_type=JointType.PLANAR,
        required=(_MateRequirement(MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE),),
        description="Coincident planes (alternate planar)",
    ),
    _PatternRule(
        pattern=JointPattern.FIXED,
        joint_type=JointType.FIXED,
        required=(_MateRequirement(MateType.LOCK),),
        description="Lock mate",
    ),
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Matching logic
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _mate_matches_requirement(mate: Mate, req: _MateRequirement) -> bool:
    """Return True if *mate* satisfies *req*.

    The mate type must match by identity; each geometry constraint is
    checked only when the requirement specifies one (``None`` = wildcard).
    """
    if mate.mate_type is not req.mate_type:
        return False
    geometry_a_ok = req.geometry_a is None or mate.ref_a.geometry_type is req.geometry_a
    geometry_b_ok = req.geometry_b is None or mate.ref_b.geometry_type is req.geometry_b
    return geometry_a_ok and geometry_b_ok
|
|
||||||
|
|
||||||
|
|
||||||
def _try_match_rule(
    rule: _PatternRule,
    mates: list[Mate],
) -> tuple[float, list[Mate], list[str]]:
    """Try to match a rule against a group of mates.

    Each requirement consumes at most one mate; a mate already matched
    by an earlier requirement is not reused.

    Args:
        rule: Pattern rule to evaluate.
        mates: Candidate mates (all between one body pair).

    Returns:
        (confidence, matched_mates, missing_descriptions) where
        confidence is the fraction of requirements satisfied.
    """
    matched: list[Mate] = []
    missing: list[str] = []

    for req in rule.required:
        found = False
        for mate in mates:
            # Identity check instead of `mate in matched`: Mate is a
            # dataclass whose GeometryRef fields hold NumPy arrays, so
            # `==` between two distinct mates can raise "truth value of
            # an array is ambiguous".
            if any(m is mate for m in matched):
                continue
            if _mate_matches_requirement(mate, req):
                matched.append(mate)
                found = True
                break
        if not found:
            # Describe the missing requirement, e.g. "COINCIDENT (plane-plane)".
            geom_desc = ""
            if req.geometry_a is not None:
                geom_b = req.geometry_b.value if req.geometry_b else "*"
                geom_desc = f" ({req.geometry_a.value}-{geom_b})"
            missing.append(f"{req.mate_type.name}{geom_desc}")

    total_required = len(rule.required)
    if total_required == 0:
        # Degenerate rule with no requirements never matches.
        return 0.0, [], []

    confidence = len(matched) / total_required
    return confidence, matched, missing
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_body_pair(body_a: int, body_b: int) -> tuple[int, int]:
    """Order a body pair ascending so (a, b) and (b, a) map to one key."""
    if body_a <= body_b:
        return (body_a, body_b)
    return (body_b, body_a)
|
|
||||||
|
|
||||||
|
|
||||||
def recognize_patterns(mates: list[Mate]) -> list[PatternMatch]:
    """Identify joint patterns from a list of mates.

    Groups mates by body pair, then checks each group against
    canonical pattern rules. Returns matches sorted by confidence
    descending.

    Args:
        mates: List of mate constraints to analyze.

    Returns:
        List of PatternMatch results, highest confidence first.
    """
    if not mates:
        return []

    # Bucket the mates by their (unordered) body pair.
    by_pair: dict[tuple[int, int], list[Mate]] = defaultdict(list)
    for mate in mates:
        key = _normalize_body_pair(mate.ref_a.body_id, mate.ref_b.body_id)
        by_pair[key].append(mate)

    all_matches: list[PatternMatch] = []

    for (body_a, body_b), pair_mates in by_pair.items():
        candidates: list[PatternMatch] = []
        for rule in _PATTERN_RULES:
            confidence, matched, missing = _try_match_rule(rule, pair_mates)
            if confidence <= 0:
                continue
            candidates.append(
                PatternMatch(
                    pattern=rule.pattern,
                    mates=matched if matched else pair_mates,
                    body_a=body_a,
                    body_b=body_b,
                    confidence=confidence,
                    equivalent_joint_type=rule.joint_type,
                    missing_mates=missing,
                )
            )

        if not candidates:
            # Nothing matched: record the whole group as UNKNOWN.
            all_matches.append(
                PatternMatch(
                    pattern=JointPattern.UNKNOWN,
                    mates=pair_mates,
                    body_a=body_a,
                    body_b=body_b,
                    confidence=0.0,
                    equivalent_joint_type=JointType.DISTANCE,
                    missing_mates=[],
                )
            )
            continue

        # Best-first within the group: confidence, then larger mate sets.
        candidates.sort(key=lambda m: (-m.confidence, -len(m.mates)))
        all_matches.extend(candidates)

    # Global best-first ordering (stable across per-group runs).
    all_matches.sort(key=lambda m: -m.confidence)
    return all_matches
|
|
||||||
@@ -1,279 +0,0 @@
|
|||||||
"""Mate type definitions and geometry references for assembly constraints.
|
|
||||||
|
|
||||||
Mates are the user-facing constraint primitives in CAD (e.g. SolidWorks-style
|
|
||||||
Coincident, Concentric, Parallel). Each mate references geometry on two bodies
|
|
||||||
and removes a context-dependent number of degrees of freedom.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import enum
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"GeometryRef",
|
|
||||||
"GeometryType",
|
|
||||||
"Mate",
|
|
||||||
"MateType",
|
|
||||||
"dof_removed",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Enums
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class MateType(enum.Enum):
    """CAD mate types with default DOF-removal counts.

    Values are ``(ordinal, default_dof)`` tuples so that mate types
    sharing the same DOF count remain distinct enum members. Use the
    :attr:`default_dof` property to get the scalar constraint count.

    The actual DOF removed can be context-dependent (e.g. COINCIDENT
    removes 3 DOF for face-face but only 1 for face-point). Use
    :func:`dof_removed` for the context-aware count.
    """

    COINCIDENT = (0, 3)  # context-dependent; see _DOF_TABLE entries
    CONCENTRIC = (1, 2)  # validate() requires AXIS geometry on both refs
    PARALLEL = (2, 2)  # validate() rejects POINT geometry
    PERPENDICULAR = (3, 1)
    TANGENT = (4, 1)  # validate() requires FACE or EDGE geometry
    DISTANCE = (5, 1)  # uses Mate.value as the scalar parameter
    ANGLE = (6, 1)  # uses Mate.value as the scalar parameter
    LOCK = (7, 6)  # removes all 6 DOF (see _DOF_TABLE)

    @property
    def default_dof(self) -> int:
        """Default number of DOF removed by this mate type."""
        # The second tuple element is the DOF count; the first is only
        # there to keep same-DOF members distinct.
        return self.value[1]
|
|
||||||
|
|
||||||
|
|
||||||
class GeometryType(enum.Enum):
    """Types of geometric references used by mates.

    FACE, AXIS, and PLANE are directional — a :class:`GeometryRef` of
    one of these types must carry a ``direction`` vector (enforced by
    ``Mate.validate`` via ``_DIRECTIONAL_TYPES``).
    """

    FACE = "face"
    EDGE = "edge"
    POINT = "point"
    AXIS = "axis"
    PLANE = "plane"
|
|
||||||
|
|
||||||
|
|
||||||
# Geometry types that require a direction vector.
# Mate.validate() raises ValueError for a reference of one of these
# types whose GeometryRef.direction is None. EDGE and POINT are not
# included and therefore never require a direction.
_DIRECTIONAL_TYPES = frozenset(
    {
        GeometryType.FACE,
        GeometryType.AXIS,
        GeometryType.PLANE,
    }
)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Dataclasses
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class GeometryRef:
    """A reference to a specific geometric entity on a body.

    Attributes:
        body_id: Index of the body this geometry belongs to.
        geometry_type: What kind of geometry (face, edge, etc.).
        geometry_id: CAD identifier string (e.g. ``"Face001"``).
        origin: 3D position of the geometry reference point.
        direction: Unit direction vector. Required for FACE, AXIS, PLANE;
            ``None`` for POINT.
    """

    body_id: int
    geometry_type: GeometryType
    geometry_id: str
    origin: np.ndarray = field(default_factory=lambda: np.zeros(3))
    direction: np.ndarray | None = None

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        serialized_direction = None
        if self.direction is not None:
            serialized_direction = self.direction.tolist()
        return {
            "body_id": self.body_id,
            "geometry_type": self.geometry_type.value,
            "geometry_id": self.geometry_id,
            "origin": self.origin.tolist(),
            "direction": serialized_direction,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> GeometryRef:
        """Construct from a dict produced by :meth:`to_dict`."""
        raw_direction = data.get("direction")
        parsed_direction = (
            None if raw_direction is None else np.asarray(raw_direction, dtype=np.float64)
        )
        return cls(
            body_id=data["body_id"],
            geometry_type=GeometryType(data["geometry_type"]),
            geometry_id=data["geometry_id"],
            origin=np.asarray(data["origin"], dtype=np.float64),
            direction=parsed_direction,
        )
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class Mate:
    """A mate constraint between geometry on two bodies.

    Attributes:
        mate_id: Unique identifier for this mate.
        mate_type: The type of constraint (Coincident, Concentric, etc.).
        ref_a: Geometry reference on the first body.
        ref_b: Geometry reference on the second body.
        value: Scalar parameter for DISTANCE and ANGLE mates (0 otherwise).
        tolerance: Numeric tolerance for constraint satisfaction.
    """

    mate_id: int
    mate_type: MateType
    ref_a: GeometryRef
    ref_b: GeometryRef
    value: float = 0.0
    tolerance: float = 1e-6

    def validate(self) -> None:
        """Raise ``ValueError`` if this mate has incompatible geometry.

        Checks, in order:
        - Self-mate (both refs on same body)
        - Directional geometry types must have a direction vector
        - CONCENTRIC requires AXIS geometry on both refs
        - PARALLEL requires directional geometry (not POINT)
        - TANGENT requires surface geometry (FACE or EDGE)
        """
        if self.ref_a.body_id == self.ref_b.body_id:
            msg = f"Self-mate: ref_a and ref_b both reference body {self.ref_a.body_id}"
            raise ValueError(msg)

        labeled_refs = (("ref_a", self.ref_a), ("ref_b", self.ref_b))

        # Directional geometry must carry a direction vector.
        for label, ref in labeled_refs:
            if ref.direction is None and ref.geometry_type in _DIRECTIONAL_TYPES:
                msg = (
                    f"{label}: geometry type {ref.geometry_type.value} requires a direction vector"
                )
                raise ValueError(msg)

        if self.mate_type is MateType.CONCENTRIC:
            for label, ref in labeled_refs:
                if ref.geometry_type is not GeometryType.AXIS:
                    msg = (
                        f"CONCENTRIC mate requires AXIS geometry, "
                        f"got {ref.geometry_type.value} on {label}"
                    )
                    raise ValueError(msg)

        if self.mate_type is MateType.PARALLEL:
            for label, ref in labeled_refs:
                if ref.geometry_type is GeometryType.POINT:
                    msg = f"PARALLEL mate requires directional geometry, got POINT on {label}"
                    raise ValueError(msg)

        if self.mate_type is MateType.TANGENT:
            surface_types = frozenset({GeometryType.FACE, GeometryType.EDGE})
            for label, ref in labeled_refs:
                if ref.geometry_type not in surface_types:
                    msg = (
                        f"TANGENT mate requires surface geometry "
                        f"(FACE or EDGE), got {ref.geometry_type.value} "
                        f"on {label}"
                    )
                    raise ValueError(msg)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict."""
        serialized: dict[str, Any] = {"mate_id": self.mate_id}
        serialized["mate_type"] = self.mate_type.name
        serialized["ref_a"] = self.ref_a.to_dict()
        serialized["ref_b"] = self.ref_b.to_dict()
        serialized["value"] = self.value
        serialized["tolerance"] = self.tolerance
        return serialized

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> Mate:
        """Construct from a dict produced by :meth:`to_dict`."""
        return cls(
            mate_id=data["mate_id"],
            mate_type=MateType[data["mate_type"]],
            ref_a=GeometryRef.from_dict(data["ref_a"]),
            ref_b=GeometryRef.from_dict(data["ref_b"]),
            value=data.get("value", 0.0),
            tolerance=data.get("tolerance", 1e-6),
        )
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Context-dependent DOF removal
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
# Lookup table: (MateType, ref_a GeometryType, ref_b GeometryType) -> DOF removed.
# Entries with None match any geometry type for that position.
# Consulted by dof_removed(): exact key first, then the (mate_type, None, None)
# wildcard, then MateType.default_dof as the final fallback.
_DOF_TABLE: dict[tuple[MateType, GeometryType | None, GeometryType | None], int] = {
    # COINCIDENT — context-dependent
    (MateType.COINCIDENT, GeometryType.FACE, GeometryType.FACE): 3,
    (MateType.COINCIDENT, GeometryType.POINT, GeometryType.POINT): 3,
    (MateType.COINCIDENT, GeometryType.PLANE, GeometryType.PLANE): 3,
    (MateType.COINCIDENT, GeometryType.EDGE, GeometryType.EDGE): 2,
    # Face-point coincidence is asymmetric in the key but listed both ways.
    (MateType.COINCIDENT, GeometryType.FACE, GeometryType.POINT): 1,
    (MateType.COINCIDENT, GeometryType.POINT, GeometryType.FACE): 1,
    # CONCENTRIC
    (MateType.CONCENTRIC, GeometryType.AXIS, GeometryType.AXIS): 2,
    # PARALLEL
    (MateType.PARALLEL, GeometryType.AXIS, GeometryType.AXIS): 2,
    (MateType.PARALLEL, GeometryType.FACE, GeometryType.FACE): 2,
    (MateType.PARALLEL, GeometryType.PLANE, GeometryType.PLANE): 2,
    # TANGENT
    (MateType.TANGENT, GeometryType.FACE, GeometryType.FACE): 1,
    (MateType.TANGENT, GeometryType.FACE, GeometryType.EDGE): 1,
    (MateType.TANGENT, GeometryType.EDGE, GeometryType.FACE): 1,
    # Types where DOF is always the same regardless of geometry
    (MateType.PERPENDICULAR, None, None): 1,
    (MateType.DISTANCE, None, None): 1,
    (MateType.ANGLE, None, None): 1,
    (MateType.LOCK, None, None): 6,
}
|
|
||||||
|
|
||||||
|
|
||||||
def dof_removed(
    mate_type: MateType,
    ref_a: GeometryRef,
    ref_b: GeometryRef,
) -> int:
    """Number of scalar DOF a mate removes, given its geometry context.

    Resolution order: the exact ``(mate_type, geometry_a, geometry_b)``
    entry in :data:`_DOF_TABLE`, then the wildcard ``(mate_type, None,
    None)`` entry, and finally :attr:`MateType.default_dof`.

    Args:
        mate_type: The mate constraint type.
        ref_a: Geometry reference on the first body.
        ref_b: Geometry reference on the second body.

    Returns:
        Number of scalar DOF removed by this mate.
    """
    exact = _DOF_TABLE.get((mate_type, ref_a.geometry_type, ref_b.geometry_type))
    if exact is not None:
        return exact
    # Table values are ints, so ``None`` reliably signals a miss above;
    # fall back to the wildcard entry, then the mate type's default.
    return _DOF_TABLE.get((mate_type, None, None), mate_type.default_dof)
|
|
||||||
@@ -0,0 +1,31 @@
|
|||||||
|
"""GNN models for assembly constraint analysis."""
|
||||||
|
|
||||||
|
from solver.models.assembly_gnn import AssemblyGNN
|
||||||
|
from solver.models.encoder import GATEncoder, GINEncoder
|
||||||
|
from solver.models.factory import build_loss, build_model
|
||||||
|
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
|
||||||
|
from solver.models.heads import (
|
||||||
|
DOFRegressionHead,
|
||||||
|
DOFTrackingHead,
|
||||||
|
EdgeClassificationHead,
|
||||||
|
GraphClassificationHead,
|
||||||
|
JointTypeHead,
|
||||||
|
)
|
||||||
|
from solver.models.losses import MultiTaskLoss
|
||||||
|
|
||||||
|
# Public API of ``solver.models`` — kept alphabetically sorted.
__all__ = [
    "ASSEMBLY_CLASSES",
    "AssemblyGNN",
    "DOFRegressionHead",
    "DOFTrackingHead",
    "EdgeClassificationHead",
    "GATEncoder",
    "GINEncoder",
    "GraphClassificationHead",
    "JOINT_TYPE_NAMES",
    "JointTypeHead",
    "MultiTaskLoss",
    "assembly_to_pyg",
    "build_loss",
    "build_model",
]
|
||||||
|
|||||||
131
solver/models/assembly_gnn.py
Normal file
131
solver/models/assembly_gnn.py
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
"""AssemblyGNN -- main model wiring encoder and task heads."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
from solver.models.encoder import GATEncoder, GINEncoder
|
||||||
|
from solver.models.heads import (
|
||||||
|
DOFRegressionHead,
|
||||||
|
DOFTrackingHead,
|
||||||
|
EdgeClassificationHead,
|
||||||
|
GraphClassificationHead,
|
||||||
|
JointTypeHead,
|
||||||
|
)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
__all__ = ["AssemblyGNN"]

# Registry mapping the ``encoder_type`` constructor argument of AssemblyGNN
# to the encoder class it instantiates.
_ENCODERS = {
    "gin": GINEncoder,
    "gat": GATEncoder,
}
|
||||||
|
|
||||||
|
|
||||||
|
class AssemblyGNN(nn.Module):
    """Multi-task GNN for assembly constraint analysis.

    Combines a shared encoder (GIN or GAT) with any subset of the
    task-specific prediction heads: edge classification, graph
    classification, joint type prediction, total-DOF regression, and
    per-body DOF tracking.

    Args:
        encoder_type: ``"gin"`` or ``"gat"``.
        encoder_config: Kwargs passed to the encoder constructor.
        heads_config: Dict of head name → config dict. Each entry must have
            an ``enabled`` bool. Additional keys are passed as kwargs.
    """

    def __init__(
        self,
        encoder_type: str = "gin",
        encoder_config: dict[str, Any] | None = None,
        heads_config: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        super().__init__()
        if encoder_type not in _ENCODERS:
            msg = f"Unknown encoder type: {encoder_type!r}. Choose from {list(_ENCODERS)}"
            raise ValueError(msg)

        self.encoder = _ENCODERS[encoder_type](**(encoder_config or {}))
        self.heads = nn.ModuleDict()
        self._build_heads(heads_config or {}, self.encoder.hidden_dim)

    def _build_heads(
        self,
        heads_config: dict[str, dict[str, Any]],
        hidden_dim: int,
    ) -> None:
        """Instantiate every head whose config section has ``enabled: true``."""
        # (config section, head name, factory) triples, in registration order.
        specs = (
            (
                "edge_classification",
                "edge_pred",
                lambda cfg: EdgeClassificationHead(
                    hidden_dim=hidden_dim,
                    inner_dim=cfg.get("hidden_dim", 64),
                ),
            ),
            (
                "graph_classification",
                "graph_pred",
                lambda cfg: GraphClassificationHead(
                    hidden_dim=hidden_dim,
                    num_classes=cfg.get("num_classes", 4),
                ),
            ),
            (
                "joint_type",
                "joint_type_pred",
                lambda cfg: JointTypeHead(
                    hidden_dim=hidden_dim,
                    num_classes=cfg.get("num_classes", 11),
                ),
            ),
            (
                "dof_regression",
                "dof_pred",
                lambda cfg: DOFRegressionHead(hidden_dim=hidden_dim),
            ),
            (
                "dof_tracking",
                "body_dof_pred",
                lambda cfg: DOFTrackingHead(hidden_dim=hidden_dim),
            ),
        )
        for section, head_name, factory in specs:
            cfg = heads_config.get(section, {})
            if cfg.get("enabled", False):
                self.heads[head_name] = factory(cfg)

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor | None = None,
    ) -> dict[str, torch.Tensor]:
        """Run encoder and all enabled heads.

        Returns:
            Dict with keys matching enabled head names:
            ``edge_pred``, ``graph_pred``, ``joint_type_pred``,
            ``dof_pred``, ``body_dof_pred``.
        """
        node_emb, edge_emb, graph_emb = self.encoder(x, edge_index, edge_attr, batch)

        # Each head consumes one of the three embedding granularities.
        source = {
            "edge_pred": edge_emb,
            "joint_type_pred": edge_emb,
            "graph_pred": graph_emb,
            "dof_pred": graph_emb,
            "body_dof_pred": node_emb,
        }
        return {
            name: head(source[name])
            for name, head in self.heads.items()
            if name in source
        }
|
||||||
194
solver/models/encoder.py
Normal file
194
solver/models/encoder.py
Normal file
@@ -0,0 +1,194 @@
|
|||||||
|
"""GIN and GAT graph neural network encoders."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from torch_geometric.nn import GATv2Conv, GINEConv, global_mean_pool
|
||||||
|
|
||||||
|
__all__ = ["GATEncoder", "GINEncoder"]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_gin_mlp(in_dim: int, hidden_dim: int) -> nn.Sequential:
    """Build the two-layer perceptron wrapped by each GINEConv layer."""
    layers = [
        nn.Linear(in_dim, hidden_dim),
        nn.BatchNorm1d(hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
    ]
    return nn.Sequential(*layers)
|
||||||
|
|
||||||
|
|
||||||
|
class GINEncoder(nn.Module):
    """Graph Isomorphism Network encoder with edge features (GINE).

    Args:
        node_features_dim: Input node feature dimension.
        edge_features_dim: Input edge feature dimension.
        hidden_dim: Hidden dimension for all layers.
        num_layers: Number of GINEConv layers.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        node_features_dim: int = 22,
        edge_features_dim: int = 22,
        hidden_dim: int = 128,
        num_layers: int = 3,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self._hidden_dim = hidden_dim

        def input_proj(in_dim: int) -> nn.Sequential:
            # Node and edge inputs get the same projection shape.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
            )

        self.node_proj = input_proj(node_features_dim)
        self.edge_proj = input_proj(edge_features_dim)

        self.convs = nn.ModuleList(
            [
                GINEConv(nn=_make_gin_mlp(hidden_dim, hidden_dim), edge_dim=hidden_dim)
                for _ in range(num_layers)
            ]
        )
        self.norms = nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_layers)]
        )

        self.dropout = nn.Dropout(dropout)

        # Edge embedding mixes both endpoint embeddings with the edge input.
        self.edge_mlp = nn.Linear(hidden_dim * 3, hidden_dim)

    @property
    def hidden_dim(self) -> int:
        """Embedding width shared by node, edge, and graph outputs."""
        return self._hidden_dim

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode graph and return node, edge, and graph embeddings.

        Returns:
            node_emb: [N, hidden_dim]
            edge_emb: [E, hidden_dim]
            graph_emb: [B, hidden_dim]
        """
        node_h = self.node_proj(x)
        edge_h = self.edge_proj(edge_attr)

        for conv, norm in zip(self.convs, self.norms):
            node_h = self.dropout(torch.relu(norm(conv(node_h, edge_index, edge_h))))

        # Per-edge embedding from both endpoints plus the projected edge input.
        src, dst = edge_index
        edge_emb = self.edge_mlp(torch.cat([node_h[src], node_h[dst], edge_h], dim=1))

        # One embedding per graph via mean pooling.
        graph_emb = global_mean_pool(node_h, batch)

        return node_h, edge_emb, graph_emb
|
||||||
|
|
||||||
|
|
||||||
|
class GATEncoder(nn.Module):
    """Graph Attention Network v2 encoder with edge features and residuals.

    Args:
        node_features_dim: Input node feature dimension.
        edge_features_dim: Input edge feature dimension.
        hidden_dim: Hidden dimension (must be divisible by num_heads).
        num_layers: Number of GATv2Conv layers.
        num_heads: Number of attention heads.
        dropout: Dropout probability.
        residual: Use residual connections.
    """

    def __init__(
        self,
        node_features_dim: int = 22,
        edge_features_dim: int = 22,
        hidden_dim: int = 256,
        num_layers: int = 4,
        num_heads: int = 8,
        dropout: float = 0.1,
        residual: bool = True,
    ) -> None:
        super().__init__()
        if hidden_dim % num_heads:
            msg = f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
            raise ValueError(msg)

        self._hidden_dim = hidden_dim
        self.residual = residual

        def input_proj(in_dim: int) -> nn.Sequential:
            # Node and edge inputs get the same projection shape.
            return nn.Sequential(
                nn.Linear(in_dim, hidden_dim),
                nn.LayerNorm(hidden_dim),
                nn.ReLU(),
            )

        self.node_proj = input_proj(node_features_dim)
        self.edge_proj = input_proj(edge_features_dim)

        # Heads are concatenated, so each head emits hidden_dim // num_heads.
        self.convs = nn.ModuleList(
            [
                GATv2Conv(
                    in_channels=hidden_dim,
                    out_channels=hidden_dim // num_heads,
                    heads=num_heads,
                    edge_dim=hidden_dim,
                    concat=True,
                )
                for _ in range(num_layers)
            ]
        )
        self.norms = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(num_layers)])

        self.dropout = nn.Dropout(dropout)

        # Edge embedding mixes both endpoint embeddings with the edge input.
        self.edge_mlp = nn.Linear(hidden_dim * 3, hidden_dim)

    @property
    def hidden_dim(self) -> int:
        """Embedding width shared by node, edge, and graph outputs."""
        return self._hidden_dim

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: torch.Tensor,
        batch: torch.Tensor | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode graph and return node, edge, and graph embeddings."""
        node_h = self.node_proj(x)
        edge_h = self.edge_proj(edge_attr)

        for conv, norm in zip(self.convs, self.norms):
            update = self.dropout(torch.relu(norm(conv(node_h, edge_index, edge_h))))
            # Residual keeps the running embedding; otherwise replace it.
            node_h = node_h + update if self.residual else update

        src, dst = edge_index
        edge_emb = self.edge_mlp(torch.cat([node_h[src], node_h[dst], edge_h], dim=1))
        graph_emb = global_mean_pool(node_h, batch)

        return node_h, edge_emb, graph_emb
|
||||||
82
solver/models/factory.py
Normal file
82
solver/models/factory.py
Normal file
@@ -0,0 +1,82 @@
|
|||||||
|
"""Factory functions to build model and loss from config dicts."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from solver.models.assembly_gnn import AssemblyGNN
|
||||||
|
from solver.models.losses import MultiTaskLoss
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = ["build_loss", "build_model"]
|
||||||
|
|
||||||
|
|
||||||
|
def build_model(config: dict[str, Any]) -> AssemblyGNN:
    """Construct an AssemblyGNN from a parsed YAML model config.

    Expected config structure (matches ``configs/model/*.yaml``)::

        architecture: gin  # or gat
        encoder:
          hidden_dim: 128
          num_layers: 3
          ...
        node_features_dim: 22
        edge_features_dim: 22
        heads:
          edge_classification:
            enabled: true
          ...

    Args:
        config: Parsed YAML model config dict.

    Returns:
        Configured ``AssemblyGNN`` instance.
    """
    encoder_config: dict[str, Any] = dict(config.get("encoder", {}))
    # Feature dims live at the top level of the config; only inject them when
    # the encoder section has not overridden them.
    for dim_key in ("node_features_dim", "edge_features_dim"):
        encoder_config.setdefault(dim_key, config.get(dim_key, 22))

    return AssemblyGNN(
        encoder_type=config.get("architecture", "gin"),
        encoder_config=encoder_config,
        heads_config=config.get("heads", {}),
    )
|
||||||
|
|
||||||
|
|
||||||
|
def build_loss(config: dict[str, Any]) -> MultiTaskLoss:
    """Construct a MultiTaskLoss from a parsed YAML training config.

    Expected config structure (from ``configs/training/*.yaml`` ``loss`` section)::

        loss:
          edge_weight: 1.0
          graph_weight: 0.5
          joint_type_weight: 0.3
          dof_weight: 0.2
          body_dof_weight: 0.2
          redundant_penalty: 2.0

    Args:
        config: Parsed YAML training config dict (full config, not just loss section).

    Returns:
        Configured ``MultiTaskLoss`` instance.
    """
    # Fallbacks mirror MultiTaskLoss's own constructor defaults.
    defaults = {
        "edge_weight": 1.0,
        "graph_weight": 0.5,
        "joint_type_weight": 0.3,
        "dof_weight": 0.2,
        "body_dof_weight": 0.2,
        "redundant_penalty": 2.0,
    }
    section = config.get("loss", {})
    kwargs = {name: section.get(name, fallback) for name, fallback in defaults.items()}
    return MultiTaskLoss(**kwargs)
|
||||||
260
solver/models/graph_conv.py
Normal file
260
solver/models/graph_conv.py
Normal file
@@ -0,0 +1,260 @@
|
|||||||
|
"""Convert datagen assembly dicts to PyTorch Geometric Data objects."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch_geometric.data import Data
|
||||||
|
|
||||||
|
from solver.datagen.types import JointType
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
__all__ = [
    "ASSEMBLY_CLASSES",
    "JOINT_TYPE_NAMES",
    "assembly_to_pyg",
]

# Ordered list matching JointType ordinal values (0-10); index == class id
# used by the joint-type prediction head.
JOINT_TYPE_NAMES: list[str] = [jt.name for jt in JointType]

# Assembly classification label mapping (graph-level class ids).
ASSEMBLY_CLASSES: dict[str, int] = {
    "well-constrained": 0,
    "underconstrained": 1,
    "overconstrained": 2,
    "mixed": 3,
}

# Joint type name -> ordinal for fast lookup.
# NOTE(review): assumes each ``JointType.value`` is a tuple whose first
# element is the ordinal — confirm against solver.datagen.types.
_JOINT_TYPE_TO_ORD: dict[str, int] = {jt.name: jt.value[0] for jt in JointType}

# Joint type name -> DOF removed (per the enum's ``dof`` attribute).
_JOINT_TYPE_TO_DOF: dict[str, int] = {jt.name: jt.dof for jt in JointType}

# Width of the joint-type one-hot segment in edge features.
_NUM_JOINT_TYPES = len(JointType)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_node_features(
    body_positions: list[list[float]],
    body_orientations: list[list[list[float]]],
    joints: list[dict[str, Any]],
    grounded: bool,
) -> torch.Tensor:
    """Encode node features as a [N, 22] tensor.

    Layout (22 dims):
        0-2   position (centered per graph)
        3-11  flattened 3x3 rotation matrix
        12    is-grounded flag (body 0 only)
        13    node degree / 10
        14-19 degree bucket one-hot (0, 1, 2, 3, 4, 5+)
        20    total incident DOF removed / 30
        21    fraction of incident joints that are FIXED
    """
    n_bodies = len(body_positions)

    # Positions, centered on the per-graph centroid.
    pos = torch.tensor(body_positions, dtype=torch.float32)
    pos = pos - pos.mean(dim=0, keepdim=True)

    # Flattened 3x3 orientation matrices -> [N, 9].
    orient = torch.tensor(body_orientations, dtype=torch.float32).reshape(n_bodies, 9)

    # Accumulate per-node degree, removed-DOF, and FIXED-joint counts.
    degree = torch.zeros(n_bodies, dtype=torch.float32)
    dof_removed = torch.zeros(n_bodies, dtype=torch.float32)
    fixed_count = torch.zeros(n_bodies, dtype=torch.float32)
    for joint in joints:
        jtype = joint["type"]
        dof = float(_JOINT_TYPE_TO_DOF.get(jtype, 0))
        for body in (joint["body_a"], joint["body_b"]):
            degree[body] += 1
            dof_removed[body] += dof
            if jtype == "FIXED":
                fixed_count[body] += 1

    # Grounded flag: convention marks body 0 when the assembly is grounded.
    grounded_flag = torch.zeros(n_bodies, 1, dtype=torch.float32)
    if grounded and n_bodies > 0:
        grounded_flag[0, 0] = 1.0

    # Degree bucket one-hot [N, 6]: buckets 0, 1, 2, 3, 4, 5+.
    bucket = degree.clamp(max=5).long()
    degree_onehot = torch.zeros(n_bodies, 6, dtype=torch.float32)
    degree_onehot.scatter_(1, bucket.unsqueeze(1), 1.0)

    # Final layout: [N, 3+9+1+1+6+1+1] = [N, 22].
    columns = [
        pos,
        orient,
        grounded_flag,
        (degree / 10.0).unsqueeze(1),
        degree_onehot,
        (dof_removed / 30.0).unsqueeze(1),
        # clamp avoids division by zero for isolated bodies.
        (fixed_count / degree.clamp(min=1)).unsqueeze(1),
    ]
    return torch.cat(columns, dim=1)
|
||||||
|
|
||||||
|
|
||||||
|
def _encode_edge_features(
    joints: list[dict[str, Any]],
    body_positions: list[list[float]],
) -> tuple[torch.Tensor, torch.Tensor]:
    """Encode edge features and build a bidirectional edge_index.

    Each joint yields two directed edges (a->b and b->a) sharing one feature
    vector: 11-dim joint-type one-hot, 3-dim axis, two 3-dim anchor offsets,
    pitch, and normalized DOF.

    Returns:
        edge_index: [2, 2*n_joints] tensor of directed edges.
        edge_attr: [2*n_joints, 22] edge features.
    """
    if not joints:
        return (
            torch.zeros(2, 0, dtype=torch.long),
            torch.zeros(0, 22, dtype=torch.float32),
        )

    def anchor_offset(anchor: list[float] | None, body: int) -> list[float]:
        # Anchor relative to the body origin; zeros when the anchor is absent.
        if anchor is None:
            return [0.0, 0.0, 0.0]
        return [anchor[k] - body_positions[body][k] for k in range(3)]

    sources: list[int] = []
    targets: list[int] = []
    rows: list[list[float]] = []

    for joint in joints:
        a, b = joint["body_a"], joint["body_b"]
        name = joint["type"]

        # One-hot joint type [11]; unknown names map to ordinal 0.
        onehot = [0.0] * _NUM_JOINT_TYPES
        onehot[_JOINT_TYPE_TO_ORD.get(name, 0)] = 1.0

        feat = (
            onehot
            + joint.get("axis", [0.0, 0.0, 1.0])
            + anchor_offset(joint.get("anchor_a"), a)
            + anchor_offset(joint.get("anchor_b"), b)
            + [joint.get("pitch", 0.0), _JOINT_TYPE_TO_DOF.get(name, 0) / 6.0]
        )

        # Bidirectional: a->b and b->a with identical features.
        sources.extend((a, b))
        targets.extend((b, a))
        rows.extend((feat, feat))

    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_attr = torch.tensor(rows, dtype=torch.float32)
    return edge_index, edge_attr
|
||||||
|
|
||||||
|
|
||||||
|
def assembly_to_pyg(
    example: dict[str, Any],
    *,
    include_labels: bool = True,
) -> Data:
    """Convert a datagen training example dict to a PyG Data object.

    Args:
        example: Dict from ``SyntheticAssemblyGenerator.generate_training_batch()``.
        include_labels: Attach ground truth labels to the Data object.

    Returns:
        ``torch_geometric.data.Data`` with node features ``x``, ``edge_index``,
        ``edge_attr``, and optionally label tensors ``y_edge``, ``y_graph``,
        ``y_joint_type``, ``y_dof``, ``y_body_dof``.
    """
    joints = example["joints"]
    positions = example["body_positions"]

    node_features = _encode_node_features(
        positions,
        example["body_orientations"],
        joints,
        example.get("grounded", False),
    )
    edge_index, edge_attr = _encode_edge_features(joints, positions)

    data = Data(x=node_features, edge_index=edge_index, edge_attr=edge_attr)
    # Set explicitly so isolated bodies (no incident edges) are still counted.
    data.num_nodes = node_features.size(0)

    if include_labels:
        _attach_labels(data, example, joints)
    return data
|
||||||
|
|
||||||
|
|
||||||
|
def _attach_labels(
    data: Data,
    example: dict[str, Any],
    joints: list[dict[str, Any]],
) -> None:
    """Attach ground truth label tensors to a Data object.

    Sets ``y_edge`` [2J] (1.0 = independent, 0.0 = redundant),
    ``y_joint_type`` [2J] (joint type ordinals), ``y_graph`` [1]
    (assembly class id), ``y_dof`` [1] (total DOF), and ``y_body_dof``
    [N, 2] (translational, rotational DOF per body).
    """
    # NOTE(review): assumes the example carries "joint_labels" and "labels"
    # dicts per the datagen schema — absent sections default to zeros/class 0.
    joint_labels = example.get("joint_labels", {})
    labels = example.get("labels", {})

    # Per-edge: binary independent (1) / redundant (0).
    # Duplicated for bidirectional edges.
    n_joints = len(joints)
    edge_labels: list[float] = []
    joint_type_labels: list[int] = []
    for j in joints:
        jid = j["joint_id"]
        # joint_labels may be keyed by the raw id or its string form
        # (e.g. after a JSON round-trip) — try both.
        jl = joint_labels.get(jid) or joint_labels.get(str(jid), {})
        is_independent = 1.0 if jl.get("redundant_constraints", 0) == 0 else 0.0
        ordinal = _JOINT_TYPE_TO_ORD.get(j["type"], 0)
        # Bidirectional: duplicate.
        edge_labels.extend([is_independent, is_independent])
        joint_type_labels.extend([ordinal, ordinal])

    if n_joints > 0:
        data.y_edge = torch.tensor(edge_labels, dtype=torch.float32)
        data.y_joint_type = torch.tensor(joint_type_labels, dtype=torch.long)
    else:
        # Keep label tensors present (but empty) so batching stays uniform.
        data.y_edge = torch.zeros(0, dtype=torch.float32)
        data.y_joint_type = torch.zeros(0, dtype=torch.long)

    # Assembly classification; unknown strings fall back to class 0.
    classification = example.get("assembly_classification", "")
    data.y_graph = torch.tensor(
        [ASSEMBLY_CLASSES.get(classification, 0)],
        dtype=torch.long,
    )

    # Total DOF (graph-level regression target).
    assembly_labels = labels.get("assembly", {})
    data.y_dof = torch.tensor(
        [float(assembly_labels.get("total_dof", 0))],
        dtype=torch.float32,
    )

    # Per-body DOF: [N, 2] (translational, rotational).
    # Bodies missing from "per_body" keep zeros; out-of-range ids are ignored.
    per_body = labels.get("per_body", [])
    n_bodies = data.num_nodes
    body_dof = torch.zeros(n_bodies, 2, dtype=torch.float32)
    for entry in per_body:
        bid = entry["body_id"]
        if 0 <= bid < n_bodies:
            body_dof[bid, 0] = float(entry["translational_dof"])
            body_dof[bid, 1] = float(entry["rotational_dof"])
    data.y_body_dof = body_dof
|
||||||
119
solver/models/heads.py
Normal file
119
solver/models/heads.py
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
"""Task-specific prediction heads for assembly GNN."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# Public API of solver.models.heads — kept alphabetically sorted.
__all__ = [
    "DOFRegressionHead",
    "DOFTrackingHead",
    "EdgeClassificationHead",
    "GraphClassificationHead",
    "JointTypeHead",
]
|
||||||
|
|
||||||
|
|
||||||
|
class EdgeClassificationHead(nn.Module):
    """Scores each edge with one logit: independent (1) vs redundant (0).

    Args:
        hidden_dim: Input embedding dimension.
        inner_dim: Internal MLP hidden dimension.
    """

    def __init__(self, hidden_dim: int = 128, inner_dim: int = 64) -> None:
        super().__init__()
        layers = [
            nn.Linear(hidden_dim, inner_dim),
            nn.ReLU(),
            nn.Linear(inner_dim, 1),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Compute per-edge logits of shape ``[E, 1]``."""
        logits = self.mlp(edge_emb)
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class GraphClassificationHead(nn.Module):
    """Classifies each graph as well/under/over-constrained or mixed.

    Args:
        hidden_dim: Input embedding dimension.
        num_classes: Number of classification categories.
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 4) -> None:
        super().__init__()
        layers = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Compute per-graph logits of shape ``[B, num_classes]``."""
        logits = self.mlp(graph_emb)
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class JointTypeHead(nn.Module):
    """Predicts the joint type of each edge from its embedding.

    Args:
        hidden_dim: Input embedding dimension.
        num_classes: Number of joint types.
    """

    def __init__(self, hidden_dim: int = 128, num_classes: int = 11) -> None:
        super().__init__()
        bottleneck = hidden_dim // 2
        layers = [
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, num_classes),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, edge_emb: torch.Tensor) -> torch.Tensor:
        """Compute per-edge logits of shape ``[E, num_classes]``."""
        logits = self.mlp(edge_emb)
        return logits
|
||||||
|
|
||||||
|
|
||||||
|
class DOFRegressionHead(nn.Module):
    """Regresses the assembly's total DOF from the pooled graph embedding.

    A final Softplus keeps the prediction non-negative.

    Args:
        hidden_dim: Input embedding dimension.
    """

    def __init__(self, hidden_dim: int = 128) -> None:
        super().__init__()
        bottleneck = hidden_dim // 2
        layers = [
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 1),
            nn.Softplus(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, graph_emb: torch.Tensor) -> torch.Tensor:
        """Predict a non-negative DOF value per graph, shape ``[B, 1]``."""
        prediction = self.mlp(graph_emb)
        return prediction
|
||||||
|
|
||||||
|
|
||||||
|
class DOFTrackingHead(nn.Module):
    """Predicts per-body (translational, rotational) DOF from node embeddings.

    A final Softplus keeps both outputs non-negative.

    Args:
        hidden_dim: Input embedding dimension.
    """

    def __init__(self, hidden_dim: int = 128) -> None:
        super().__init__()
        bottleneck = hidden_dim // 2
        layers = [
            nn.Linear(hidden_dim, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, 2),
            nn.Softplus(),
        ]
        self.mlp = nn.Sequential(*layers)

    def forward(self, node_emb: torch.Tensor) -> torch.Tensor:
        """Predict non-negative per-body DOF, shape ``[N, 2]``."""
        prediction = self.mlp(node_emb)
        return prediction
|
||||||
161
solver/models/losses.py
Normal file
161
solver/models/losses.py
Normal file
@@ -0,0 +1,161 @@
|
|||||||
|
"""Uncertainty-weighted multi-task loss (Kendall et al., 2018)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
|
||||||
|
# NOTE: a dead ``if TYPE_CHECKING: pass`` placeholder (guarding no imports)
# was removed here.
__all__ = ["MultiTaskLoss"]
|
||||||
|
|
||||||
|
|
||||||
|
class MultiTaskLoss(nn.Module):
    """Multi-task loss with learnable uncertainty weighting.

    Each task has a learnable ``log_var`` parameter (log variance) that
    automatically balances task contributions during training. The loss
    for task *i* is::

        (1 / (2 * sigma_i^2)) * weight_i * L_i + 0.5 * log(sigma_i^2)

    which simplifies to::

        exp(-log_var_i) * weight_i * L_i + 0.5 * log_var_i

    Args:
        edge_weight: Initial scale for edge classification loss.
        graph_weight: Initial scale for graph classification loss.
        joint_type_weight: Initial scale for joint type loss.
        dof_weight: Initial scale for DOF regression loss.
        body_dof_weight: Initial scale for per-body DOF loss.
        redundant_penalty: Extra weight on redundant edges (label=0) in
            the edge BCE loss.
    """

    def __init__(
        self,
        edge_weight: float = 1.0,
        graph_weight: float = 0.5,
        joint_type_weight: float = 0.3,
        dof_weight: float = 0.2,
        body_dof_weight: float = 0.2,
        redundant_penalty: float = 2.0,
    ) -> None:
        super().__init__()
        # Fixed per-task scales; applied on top of the learned exp(-log_var)
        # factor in _weighted(). Plain dict: these are hyperparameters, not
        # trainable parameters.
        self.weights: dict[str, float] = {
            "edge": edge_weight,
            "graph": graph_weight,
            "joint_type": joint_type_weight,
            "dof": dof_weight,
            "body_dof": body_dof_weight,
        }
        self.redundant_penalty = redundant_penalty

        # Learnable log-variance parameters, one per task.
        # Initialized to 0 → sigma^2 = 1.
        self.log_vars = nn.ParameterDict(
            {name: nn.Parameter(torch.zeros(1)) for name in self.weights}
        )

    def forward(
        self,
        predictions: dict[str, torch.Tensor],
        targets: dict[str, torch.Tensor],
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Compute total loss and per-task breakdown.

        Args:
            predictions: Dict with keys from AssemblyGNN output:
                ``edge_pred``, ``graph_pred``, ``joint_type_pred``,
                ``dof_pred``, ``body_dof_pred``.
            targets: Dict with label tensors:
                ``y_edge``, ``y_graph``, ``y_joint_type``,
                ``y_dof``, ``y_body_dof``.

        Returns:
            total_loss: Scalar total loss.
            breakdown: Dict of per-task raw loss values (before weighting).
        """
        # Every task is optional: it contributes only when both its
        # prediction and its label are present in the input dicts.
        total = torch.tensor(0.0, device=self._device(predictions))
        breakdown: dict[str, torch.Tensor] = {}

        # Edge classification (BCE with asymmetric redundancy penalty).
        if "edge_pred" in predictions and "y_edge" in targets:
            loss = self._edge_loss(predictions["edge_pred"], targets["y_edge"])
            total = total + self._weighted(loss, "edge")
            breakdown["edge"] = loss.detach()

        # Graph classification.
        if "graph_pred" in predictions and "y_graph" in targets:
            loss = nn.functional.cross_entropy(
                predictions["graph_pred"],
                targets["y_graph"],
            )
            total = total + self._weighted(loss, "graph")
            breakdown["graph"] = loss.detach()

        # Joint type classification.
        if "joint_type_pred" in predictions and "y_joint_type" in targets:
            loss = nn.functional.cross_entropy(
                predictions["joint_type_pred"],
                targets["y_joint_type"],
            )
            total = total + self._weighted(loss, "joint_type")
            breakdown["joint_type"] = loss.detach()

        # DOF regression.
        if "dof_pred" in predictions and "y_dof" in targets:
            loss = nn.functional.smooth_l1_loss(
                predictions["dof_pred"],
                targets["y_dof"],
            )
            total = total + self._weighted(loss, "dof")
            breakdown["dof"] = loss.detach()

        # Per-body DOF tracking.
        if "body_dof_pred" in predictions and "y_body_dof" in targets:
            loss = nn.functional.smooth_l1_loss(
                predictions["body_dof_pred"],
                targets["y_body_dof"],
            )
            total = total + self._weighted(loss, "body_dof")
            breakdown["body_dof"] = loss.detach()

        return total, breakdown

    def _edge_loss(
        self,
        pred: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """BCE loss with asymmetric weighting for redundant edges."""
        # pred: [E, 1] logits, target: [E] binary.
        pred_flat = pred.squeeze(-1)
        # Weight: redundant (0) gets higher penalty, independent (1) gets 1.0.
        # torch.where broadcasts the two scalar tensors to a per-edge weight.
        weight = torch.where(
            target == 0,
            torch.tensor(self.redundant_penalty, device=pred.device),
            torch.tensor(1.0, device=pred.device),
        )
        return nn.functional.binary_cross_entropy_with_logits(
            pred_flat,
            target,
            weight=weight,
        )

    def _weighted(self, loss: torch.Tensor, task_name: str) -> torch.Tensor:
        """Apply uncertainty weighting: exp(-log_var) * w * L + 0.5 * log_var."""
        # The additive 0.5 * log_var term keeps the model from driving the
        # variance to infinity to zero out the task.
        log_var = self.log_vars[task_name].squeeze()
        weight = self.weights[task_name]
        return torch.exp(-log_var) * weight * loss + 0.5 * log_var

    @staticmethod
    def _device(predictions: dict[str, torch.Tensor]) -> torch.device:
        """Infer device from prediction tensors."""
        # Uses whichever tensor iteration yields first; falls back to CPU
        # when the predictions dict is empty.
        for v in predictions.values():
            return v.device
        return torch.device("cpu")
|
||||||
@@ -1,287 +0,0 @@
|
|||||||
"""Tests for solver.mates.conversion -- mate-to-joint conversion."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from solver.datagen.types import JointType, RigidBody
|
|
||||||
from solver.mates.conversion import (
|
|
||||||
MateAnalysisResult,
|
|
||||||
analyze_mate_assembly,
|
|
||||||
convert_mates_to_joints,
|
|
||||||
)
|
|
||||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Build a GeometryRef, defaulting origin to zero and direction to +Z."""
    resolved_origin = np.zeros(3) if origin is None else origin
    directional_types = (GeometryType.FACE, GeometryType.AXIS, GeometryType.PLANE)
    resolved_direction = direction
    if resolved_direction is None and geom_type in directional_types:
        resolved_direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id="Geom001",
        origin=resolved_origin,
        direction=resolved_direction,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _make_bodies(n: int) -> list[RigidBody]:
    """Create n bodies spaced one unit apart along the x-axis."""
    bodies: list[RigidBody] = []
    for idx in range(n):
        bodies.append(RigidBody(body_id=idx, position=np.array([float(idx), 0.0, 0.0])))
    return bodies
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# convert_mates_to_joints
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestConvertMatesToJoints:
    """convert_mates_to_joints function."""

    def test_empty_input(self) -> None:
        # An empty mate list must yield empty joints and empty maps in both
        # directions.
        joints, m2j, j2m = convert_mates_to_joints([])
        assert joints == []
        assert m2j == {}
        assert j2m == {}

    def test_hinge_pattern(self) -> None:
        """Concentric + Coincident(plane) -> single REVOLUTE joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        joints, m2j, j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        assert joints[0].joint_type is JointType.REVOLUTE
        assert joints[0].body_a == 0
        assert joints[0].body_b == 1
        # Both mates map to the single joint
        assert 0 in m2j
        assert 1 in m2j
        assert j2m[joints[0].joint_id] == [0, 1]

    def test_lock_pattern(self) -> None:
        """Lock -> FIXED joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        joints, _m2j, _j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        assert joints[0].joint_type is JointType.FIXED

    def test_unmatched_mate_fallback(self) -> None:
        """A single ANGLE mate with no pattern -> individual joint."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.ANGLE,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        joints, _m2j, _j2m = convert_mates_to_joints(mates)
        assert len(joints) == 1
        # NOTE(review): the fallback mapping for ANGLE is PERPENDICULAR —
        # confirm this is the intended conversion, not a placeholder.
        assert joints[0].joint_type is JointType.PERPENDICULAR

    def test_mapping_consistency(self) -> None:
        """mate_to_joint and joint_to_mates are consistent."""
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
            Mate(
                mate_id=2,
                mate_type=MateType.DISTANCE,
                ref_a=_make_ref(2, GeometryType.POINT),
                ref_b=_make_ref(3, GeometryType.POINT),
            ),
        ]
        joints, m2j, j2m = convert_mates_to_joints(mates)
        # Every mate should be in m2j
        for mate in mates:
            assert mate.mate_id in m2j
        # Every joint should be in j2m
        for joint in joints:
            assert joint.joint_id in j2m

    def test_joint_axis_from_geometry(self) -> None:
        """Joint axis should come from mate geometry direction."""
        axis_dir = np.array([1.0, 0.0, 0.0])
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS, direction=axis_dir),
                ref_b=_make_ref(1, GeometryType.AXIS, direction=axis_dir),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        joints, _, _ = convert_mates_to_joints(mates)
        np.testing.assert_array_almost_equal(joints[0].axis, axis_dir)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MateAnalysisResult
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMateAnalysisResult:
    """MateAnalysisResult dataclass."""

    def test_to_dict(self) -> None:
        """An empty result serializes to empty lists and a None label set."""
        empty_result = MateAnalysisResult(
            patterns=[],
            joints=[],
        )
        serialized = empty_result.to_dict()
        assert serialized["labels"] is None
        assert serialized["patterns"] == []
        assert serialized["joints"] == []
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# analyze_mate_assembly
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestAnalyzeMateAssembly:
    """Full pipeline: mates -> joints -> analysis."""

    def test_two_bodies_hinge(self) -> None:
        """Two bodies connected by hinge mates -> underconstrained (1 DOF)."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert result.analysis is not None
        assert result.labels is not None
        # A revolute joint removes 5 DOF, leaving 1 internal DOF
        assert result.analysis.combinatorial_internal_dof == 1
        assert len(result.joints) == 1
        assert result.joints[0].joint_type is JointType.REVOLUTE

    def test_two_bodies_fixed(self) -> None:
        """Two bodies with lock mate -> well-constrained."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert result.analysis is not None
        # A FIXED joint removes all 6 relative DOF.
        assert result.analysis.combinatorial_internal_dof == 0
        assert result.analysis.is_rigid

    def test_grounded_assembly(self) -> None:
        """Grounded assembly analysis works."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates, ground_body=0)
        assert result.analysis is not None
        assert result.analysis.is_rigid

    def test_no_mates(self) -> None:
        """Assembly with no mates should be fully underconstrained."""
        bodies = _make_bodies(2)
        result = analyze_mate_assembly(bodies, [])
        assert result.analysis is not None
        # Two free bodies: 6 relative internal DOF remain.
        assert result.analysis.combinatorial_internal_dof == 6
        assert len(result.joints) == 0

    def test_single_body(self) -> None:
        """Single body, no mates."""
        bodies = _make_bodies(1)
        result = analyze_mate_assembly(bodies, [])
        assert result.analysis is not None
        assert len(result.joints) == 0

    def test_result_traceability(self) -> None:
        """mate_to_joint and joint_to_mates populated in result."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = analyze_mate_assembly(bodies, mates)
        assert 0 in result.mate_to_joint
        assert 1 in result.mate_to_joint
        assert len(result.joint_to_mates) > 0
|
|
||||||
@@ -1,155 +0,0 @@
|
|||||||
"""Tests for solver.mates.generator -- synthetic mate generator."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from solver.mates.generator import SyntheticMateGenerator, generate_mate_training_batch
|
|
||||||
from solver.mates.primitives import MateType
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# SyntheticMateGenerator
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestSyntheticMateGenerator:
    """SyntheticMateGenerator core functionality."""

    def test_generate_basic(self) -> None:
        """A 3-body assembly comes back with bodies, mates, and analysis."""
        generator = SyntheticMateGenerator(seed=42)
        bodies, mates, result = generator.generate(3)
        assert result.analysis is not None
        assert len(mates) > 0
        assert len(bodies) == 3

    def test_deterministic_with_seed(self) -> None:
        """Identical seeds reproduce identical mate sequences."""
        first_run = SyntheticMateGenerator(seed=123)
        _, mates1, _ = first_run.generate(3)

        second_run = SyntheticMateGenerator(seed=123)
        _, mates2, _ = second_run.generate(3)

        assert len(mates1) == len(mates2)
        for left, right in zip(mates1, mates2, strict=True):
            assert left.mate_type == right.mate_type
            assert left.ref_a.body_id == right.ref_a.body_id

    def test_grounded(self) -> None:
        """Generating a grounded assembly succeeds."""
        generator = SyntheticMateGenerator(seed=42)
        bodies, _mates, result = generator.generate(3, grounded=True)
        assert result.analysis is not None
        assert len(bodies) == 3

    def test_revolute_produces_two_mates(self) -> None:
        """A revolute joint reverse-maps to a concentric + coincident pair."""
        generator = SyntheticMateGenerator(seed=42)
        _bodies, mates, _result = generator.generate(2)
        # 2 bodies -> 1 revolute joint -> 2 mates (concentric + coincident)
        assert len(mates) == 2
        observed_types = {m.mate_type for m in mates}
        assert MateType.COINCIDENT in observed_types
        assert MateType.CONCENTRIC in observed_types
|
|
||||||
|
|
||||||
|
|
||||||
class TestReverseMapping:
    """Reverse mapping from joints to mates."""

    def test_revolute_mapping(self) -> None:
        """A REVOLUTE joint expands into Concentric + Coincident mates."""
        generator = SyntheticMateGenerator(seed=42)
        _bodies, generated_mates, _result = generator.generate(2)
        observed = [mate.mate_type for mate in generated_mates]
        assert MateType.COINCIDENT in observed
        assert MateType.CONCENTRIC in observed

    def test_round_trip_analysis(self) -> None:
        """Generated mates round-trip through analysis successfully."""
        generator = SyntheticMateGenerator(seed=42)
        _bodies, _mates, round_trip = generator.generate(4)
        assert round_trip.labels is not None
        assert round_trip.analysis is not None
        # The generated mates must reconstruct at least one joint.
        assert len(round_trip.joints) > 0
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoiseInjection:
    """Noise injection mechanisms."""

    def test_redundant_injection(self) -> None:
        """redundant_prob=1.0 yields strictly more mates than a clean run."""
        gen_clean = SyntheticMateGenerator(seed=42, redundant_prob=0.0)
        _, baseline_mates, _ = gen_clean.generate(4)

        gen_noisy = SyntheticMateGenerator(seed=42, redundant_prob=1.0)
        _, injected_mates, _ = gen_noisy.generate(4)

        assert len(injected_mates) > len(baseline_mates)

    def test_missing_injection(self) -> None:
        """missing_prob=0.5 never yields more mates than a clean run."""
        gen_clean = SyntheticMateGenerator(seed=42, missing_prob=0.0)
        _, baseline_mates, _ = gen_clean.generate(4)

        gen_noisy = SyntheticMateGenerator(seed=42, missing_prob=0.5)
        _, dropped_mates, _ = gen_noisy.generate(4)

        # With 50% drop rate on 6 mates, very likely to drop at least one
        assert len(dropped_mates) <= len(baseline_mates)

    def test_incompatible_injection(self) -> None:
        """incompatible_prob=1.0 adds mates with mismatched geometry."""
        gen_noisy = SyntheticMateGenerator(seed=42, incompatible_prob=1.0)
        _, noisy_mates, _ = gen_noisy.generate(3)
        # Compare against an identically-seeded clean generator.
        gen_clean = SyntheticMateGenerator(seed=42)
        _, clean_mates, _ = gen_clean.generate(3)
        assert len(noisy_mates) > len(clean_mates)
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# generate_mate_training_batch
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestGenerateMateTrainingBatch:
    """Batch generation function."""

    def test_batch_structure(self) -> None:
        """Each example has required keys."""
        examples = generate_mate_training_batch(batch_size=3, seed=42)
        assert len(examples) == 3
        required_keys = (
            "bodies",
            "mates",
            "patterns",
            "labels",
            "n_bodies",
            "n_mates",
            "n_joints",
        )
        for example in examples:
            for key in required_keys:
                assert key in example

    def test_batch_deterministic(self) -> None:
        """Identical seeds reproduce identical batch statistics."""
        first_batch = generate_mate_training_batch(batch_size=5, seed=99)
        second_batch = generate_mate_training_batch(batch_size=5, seed=99)
        for lhs, rhs in zip(first_batch, second_batch, strict=True):
            assert lhs["n_bodies"] == rhs["n_bodies"]
            assert lhs["n_mates"] == rhs["n_mates"]

    def test_batch_grounded_ratio(self) -> None:
        """Batch respects grounded_ratio parameter."""
        # All grounded
        grounded_batch = generate_mate_training_batch(batch_size=5, seed=42, grounded_ratio=1.0)
        assert len(grounded_batch) == 5

    def test_batch_with_noise(self) -> None:
        """Batch with noise injection runs without error."""
        noisy_batch = generate_mate_training_batch(
            batch_size=3,
            seed=42,
            redundant_prob=0.3,
            missing_prob=0.1,
        )
        assert len(noisy_batch) == 3
        for example in noisy_batch:
            assert example["n_mates"] >= 0
|
|
||||||
@@ -1,224 +0,0 @@
|
|||||||
"""Tests for solver.mates.labeling -- mate-level ground truth labels."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from solver.datagen.types import RigidBody
|
|
||||||
from solver.mates.labeling import MateAssemblyLabels, MateLabel, label_mate_assembly
|
|
||||||
from solver.mates.patterns import JointPattern
|
|
||||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Factory for a GeometryRef; origin defaults to zero, direction to +Z
    for face/axis/plane geometry."""
    needs_direction = geom_type in {
        GeometryType.FACE,
        GeometryType.AXIS,
        GeometryType.PLANE,
    }
    if direction is None and needs_direction:
        direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id="Geom001",
        origin=np.zeros(3) if origin is None else origin,
        direction=direction,
    )
|
|
||||||
|
|
||||||
|
|
||||||
def _make_bodies(n: int) -> list[RigidBody]:
    """Create n bodies placed one unit apart along x."""
    positions = (np.array([float(i), 0.0, 0.0]) for i in range(n))
    return [RigidBody(body_id=i, position=pos) for i, pos in enumerate(positions)]
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MateLabel
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMateLabel:
    """MateLabel dataclass."""

    def test_defaults(self) -> None:
        """A fresh label is independent with no flags, pattern, or issue."""
        label = MateLabel(mate_id=0)
        assert label.issue is None
        assert label.pattern is None
        assert label.is_degenerate is False
        assert label.is_redundant is False
        assert label.is_independent is True

    def test_to_dict(self) -> None:
        """A fully-populated label serializes field by field."""
        label = MateLabel(
            mate_id=5,
            is_independent=False,
            is_redundant=True,
            pattern=JointPattern.HINGE,
            issue="redundant",
        )
        payload = label.to_dict()
        assert payload["issue"] == "redundant"
        assert payload["pattern"] == "hinge"
        assert payload["is_redundant"] is True
        assert payload["mate_id"] == 5

    def test_to_dict_none_pattern(self) -> None:
        """A label without a pattern serializes pattern as None."""
        payload = MateLabel(mate_id=0).to_dict()
        assert payload["pattern"] is None
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MateAssemblyLabels
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMateAssemblyLabels:
    """MateAssemblyLabels dataclass."""

    def test_to_dict_structure(self) -> None:
        """to_dict produces expected keys."""
        two_bodies = _make_bodies(2)
        lock_mate = Mate(
            mate_id=0,
            mate_type=MateType.LOCK,
            ref_a=_make_ref(0, GeometryType.FACE),
            ref_b=_make_ref(1, GeometryType.FACE),
        )
        labels = label_mate_assembly(two_bodies, [lock_mate])
        payload = labels.to_dict()
        assert isinstance(payload["per_mate"], list)
        for section in ("per_mate", "patterns", "assembly"):
            assert section in payload
|
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# label_mate_assembly
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestLabelMateAssembly:
    """Full labeling pipeline."""

    def test_clean_assembly_no_redundancy(self) -> None:
        """Two bodies with lock mate -> clean, no redundancy."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert isinstance(result, MateAssemblyLabels)
        assert len(result.per_mate) == 1
        ml = result.per_mate[0]
        assert ml.mate_id == 0
        assert ml.is_independent is True
        assert ml.is_redundant is False
        assert ml.issue is None

    def test_redundant_assembly(self) -> None:
        """Two lock mates on same body pair -> one is redundant."""
        bodies = _make_bodies(2)
        # Second lock duplicates the first at a shifted origin so only
        # one of them can be independent.
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
                ref_b=_make_ref(1, GeometryType.FACE, origin=np.array([1.0, 0.0, 0.0])),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert len(result.per_mate) == 2
        redundant_count = sum(1 for ml in result.per_mate if ml.is_redundant)
        # At least one should be redundant
        assert redundant_count >= 1
        assert result.assembly.redundant_count > 0

    def test_hinge_pattern_labeling(self) -> None:
        """Hinge mates get pattern membership."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.CONCENTRIC,
                ref_a=_make_ref(0, GeometryType.AXIS),
                ref_b=_make_ref(1, GeometryType.AXIS),
            ),
            Mate(
                mate_id=1,
                mate_type=MateType.COINCIDENT,
                ref_a=_make_ref(0, GeometryType.PLANE),
                ref_b=_make_ref(1, GeometryType.PLANE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        assert len(result.per_mate) == 2
        # Both mates should be part of the hinge pattern
        for ml in result.per_mate:
            assert ml.pattern is JointPattern.HINGE
            assert ml.is_independent is True

    def test_grounded_assembly(self) -> None:
        """Grounded assembly labeling works."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates, ground_body=0)
        assert result.assembly.is_rigid

    def test_empty_mates(self) -> None:
        """No mates -> no per_mate labels, underconstrained."""
        bodies = _make_bodies(2)
        result = label_mate_assembly(bodies, [])
        assert len(result.per_mate) == 0
        assert result.assembly.classification == "underconstrained"

    def test_assembly_classification(self) -> None:
        """Assembly classification is present."""
        bodies = _make_bodies(2)
        mates = [
            Mate(
                mate_id=0,
                mate_type=MateType.LOCK,
                ref_a=_make_ref(0, GeometryType.FACE),
                ref_b=_make_ref(1, GeometryType.FACE),
            ),
        ]
        result = label_mate_assembly(bodies, mates)
        # Classification must be one of the four known buckets.
        assert result.assembly.classification in {
            "well-constrained",
            "overconstrained",
            "underconstrained",
            "mixed",
        }
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
"""Tests for solver.mates.patterns -- joint pattern recognition."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from solver.datagen.types import JointType
|
|
||||||
from solver.mates.patterns import JointPattern, PatternMatch, recognize_patterns
|
|
||||||
from solver.mates.primitives import GeometryRef, GeometryType, Mate, MateType
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    geometry_id: str = "Geom001",
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Build a GeometryRef, filling in sensible defaults.

    Origin defaults to the world origin. Directed geometry kinds
    (face, axis, plane) default to the +Z unit vector; other kinds
    keep ``direction=None``.
    """
    resolved_origin = np.zeros(3) if origin is None else origin
    directed_kinds = {GeometryType.FACE, GeometryType.AXIS, GeometryType.PLANE}
    resolved_direction = direction
    if resolved_direction is None and geom_type in directed_kinds:
        resolved_direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id=geometry_id,
        origin=resolved_origin,
        direction=resolved_direction,
    )
def _make_mate(
    mate_id: int,
    mate_type: MateType,
    body_a: int,
    body_b: int,
    geom_a: GeometryType = GeometryType.FACE,
    geom_b: GeometryType = GeometryType.FACE,
) -> Mate:
    """Build a Mate between two bodies via the shared _make_ref factory."""
    ref_a = _make_ref(body_a, geom_a)
    ref_b = _make_ref(body_b, geom_b)
    return Mate(mate_id=mate_id, mate_type=mate_type, ref_a=ref_a, ref_b=ref_b)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# JointPattern enum
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestJointPattern:
    """JointPattern enum."""

    def test_member_count(self) -> None:
        """The enum exposes exactly nine joint patterns."""
        assert len(JointPattern) == 9

    def test_string_values(self) -> None:
        """Every member's value is a string (serialization-friendly)."""
        for jp in JointPattern:
            assert isinstance(jp.value, str)

    def test_access_by_name(self) -> None:
        """Name-based lookup returns the identical singleton member."""
        assert JointPattern["HINGE"] is JointPattern.HINGE
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# PatternMatch
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestPatternMatch:
    """PatternMatch dataclass."""

    def test_construction(self) -> None:
        """A fully-specified match keeps its fields and defaults missing_mates to []."""
        mate = _make_mate(0, MateType.LOCK, 0, 1)
        pm = PatternMatch(
            pattern=JointPattern.FIXED,
            mates=[mate],
            body_a=0,
            body_b=1,
            confidence=1.0,
            equivalent_joint_type=JointType.FIXED,
        )
        assert pm.pattern is JointPattern.FIXED
        assert pm.confidence == 1.0
        # missing_mates must default to an empty list, not None.
        assert pm.missing_mates == []

    def test_to_dict(self) -> None:
        """Serialization uses the pattern's string value and mate ids."""
        mate = _make_mate(5, MateType.LOCK, 0, 1)
        pm = PatternMatch(
            pattern=JointPattern.FIXED,
            mates=[mate],
            body_a=0,
            body_b=1,
            confidence=1.0,
            equivalent_joint_type=JointType.FIXED,
        )
        d = pm.to_dict()
        assert d["pattern"] == "fixed"
        assert d["mate_ids"] == [5]
        assert d["equivalent_joint_type"] == "FIXED"
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# recognize_patterns — canonical patterns
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRecognizeCanonical:
    """Full-confidence canonical pattern recognition."""

    def test_empty_input(self) -> None:
        """No mates means no pattern matches."""
        assert recognize_patterns([]) == []

    def test_hinge(self) -> None:
        """Concentric(axis) + Coincident(plane) -> Hinge."""
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
            _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
        ]
        results = recognize_patterns(mates)
        # Results are confidence-sorted, so the full match comes first.
        top = results[0]
        assert top.pattern is JointPattern.HINGE
        assert top.confidence == 1.0
        assert top.equivalent_joint_type is JointType.REVOLUTE
        assert top.missing_mates == []

    def test_slider(self) -> None:
        """Coincident(plane) + Parallel(axis) -> Slider."""
        mates = [
            _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
            _make_mate(1, MateType.PARALLEL, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
        ]
        results = recognize_patterns(mates)
        top = results[0]
        assert top.pattern is JointPattern.SLIDER
        assert top.confidence == 1.0
        assert top.equivalent_joint_type is JointType.SLIDER

    def test_cylinder(self) -> None:
        """Concentric(axis) only -> Cylinder."""
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
        ]
        results = recognize_patterns(mates)
        # Should match cylinder at confidence 1.0
        cylinder = [r for r in results if r.pattern is JointPattern.CYLINDER]
        assert len(cylinder) >= 1
        assert cylinder[0].confidence == 1.0
        assert cylinder[0].equivalent_joint_type is JointType.CYLINDRICAL

    def test_ball(self) -> None:
        """Coincident(point) -> Ball."""
        mates = [
            _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.POINT, GeometryType.POINT),
        ]
        results = recognize_patterns(mates)
        top = results[0]
        assert top.pattern is JointPattern.BALL
        assert top.confidence == 1.0
        assert top.equivalent_joint_type is JointType.BALL

    def test_planar_face(self) -> None:
        """Coincident(face) -> Planar."""
        mates = [
            _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.FACE, GeometryType.FACE),
        ]
        results = recognize_patterns(mates)
        top = results[0]
        assert top.pattern is JointPattern.PLANAR
        assert top.confidence == 1.0
        assert top.equivalent_joint_type is JointType.PLANAR

    def test_fixed(self) -> None:
        """Lock -> Fixed."""
        mates = [
            _make_mate(0, MateType.LOCK, 0, 1, GeometryType.FACE, GeometryType.FACE),
        ]
        results = recognize_patterns(mates)
        top = results[0]
        assert top.pattern is JointPattern.FIXED
        assert top.confidence == 1.0
        assert top.equivalent_joint_type is JointType.FIXED
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# recognize_patterns — partial matches
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRecognizePartial:
    """Partial pattern matches and hints."""

    def test_concentric_without_plane_hints_hinge(self) -> None:
        """Concentric alone matches hinge at 0.5 confidence with missing hint."""
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
        ]
        results = recognize_patterns(mates)
        hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE]
        assert len(hinge_matches) >= 1
        hinge = hinge_matches[0]
        # Half the hinge recipe is present, so confidence is 0.5 and the
        # recognizer hints at which mate(s) would complete the pattern.
        assert hinge.confidence == 0.5
        assert len(hinge.missing_mates) > 0

    def test_coincident_plane_without_parallel_hints_slider(self) -> None:
        """Coincident(plane) alone matches slider at 0.5 confidence."""
        mates = [
            _make_mate(0, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
        ]
        results = recognize_patterns(mates)
        slider_matches = [r for r in results if r.pattern is JointPattern.SLIDER]
        assert len(slider_matches) >= 1
        assert slider_matches[0].confidence == 0.5
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# recognize_patterns — ambiguous / multi-body
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestRecognizeAmbiguous:
    """Ambiguous patterns and multi-body-pair assemblies."""

    def test_concentric_matches_both_hinge_and_cylinder(self) -> None:
        """A single concentric mate produces both hinge (partial) and cylinder matches."""
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
        ]
        results = recognize_patterns(mates)
        patterns = {r.pattern for r in results}
        assert JointPattern.HINGE in patterns
        assert JointPattern.CYLINDER in patterns

    def test_multiple_body_pairs(self) -> None:
        """Mates across different body pairs produce separate pattern matches."""
        mates = [
            _make_mate(0, MateType.LOCK, 0, 1),
            _make_mate(1, MateType.COINCIDENT, 2, 3, GeometryType.POINT, GeometryType.POINT),
        ]
        results = recognize_patterns(mates)
        pairs = {(r.body_a, r.body_b) for r in results}
        assert (0, 1) in pairs
        assert (2, 3) in pairs

    def test_results_sorted_by_confidence(self) -> None:
        """All results should be sorted by confidence descending."""
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 0, 1, GeometryType.AXIS, GeometryType.AXIS),
            _make_mate(1, MateType.LOCK, 2, 3),
        ]
        results = recognize_patterns(mates)
        confidences = [r.confidence for r in results]
        assert confidences == sorted(confidences, reverse=True)

    def test_unknown_pattern(self) -> None:
        """A mate type that matches no rule returns UNKNOWN."""
        mates = [
            _make_mate(0, MateType.ANGLE, 0, 1, GeometryType.FACE, GeometryType.FACE),
        ]
        results = recognize_patterns(mates)
        assert any(r.pattern is JointPattern.UNKNOWN for r in results)

    def test_body_pair_normalization(self) -> None:
        """Mates with reversed body order should be grouped together."""
        # Mate 0 references (1, 0) while mate 1 references (0, 1); the
        # recognizer must normalize the pair to see the complete hinge.
        mates = [
            _make_mate(0, MateType.CONCENTRIC, 1, 0, GeometryType.AXIS, GeometryType.AXIS),
            _make_mate(1, MateType.COINCIDENT, 0, 1, GeometryType.PLANE, GeometryType.PLANE),
        ]
        results = recognize_patterns(mates)
        hinge_matches = [r for r in results if r.pattern is JointPattern.HINGE]
        assert len(hinge_matches) >= 1
        assert hinge_matches[0].confidence == 1.0
@@ -1,329 +0,0 @@
|
|||||||
"""Tests for solver.mates.primitives -- mate type definitions."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import ClassVar
|
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from solver.mates.primitives import (
|
|
||||||
GeometryRef,
|
|
||||||
GeometryType,
|
|
||||||
Mate,
|
|
||||||
MateType,
|
|
||||||
dof_removed,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Helpers
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
def _make_ref(
    body_id: int,
    geom_type: GeometryType,
    *,
    geometry_id: str = "Geom001",
    origin: np.ndarray | None = None,
    direction: np.ndarray | None = None,
) -> GeometryRef:
    """Factory for GeometryRef with sensible defaults.

    Missing origins become the world origin; directed geometry
    (FACE/AXIS/PLANE) gets a +Z default direction, everything else
    is left with ``direction=None``.
    """
    if origin is None:
        origin = np.zeros(3)
    needs_direction = geom_type in {
        GeometryType.FACE,
        GeometryType.AXIS,
        GeometryType.PLANE,
    }
    if direction is None and needs_direction:
        direction = np.array([0.0, 0.0, 1.0])
    return GeometryRef(
        body_id=body_id,
        geometry_type=geom_type,
        geometry_id=geometry_id,
        origin=origin,
        direction=direction,
    )
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# MateType
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMateType:
    """MateType enum construction and DOF values."""

    # Expected default DOF-removed count for every mate type.
    EXPECTED_DOF: ClassVar[dict[str, int]] = {
        "COINCIDENT": 3,
        "CONCENTRIC": 2,
        "PARALLEL": 2,
        "PERPENDICULAR": 1,
        "TANGENT": 1,
        "DISTANCE": 1,
        "ANGLE": 1,
        "LOCK": 6,
    }

    def test_member_count(self) -> None:
        assert len(MateType) == 8

    @pytest.mark.parametrize("name,dof", EXPECTED_DOF.items())
    def test_default_dof_values(self, name: str, dof: int) -> None:
        assert MateType[name].default_dof == dof

    def test_value_is_tuple(self) -> None:
        # Each member's value is (ordinal, default_dof).
        assert MateType.COINCIDENT.value == (0, 3)
        assert MateType.COINCIDENT.default_dof == 3

    def test_access_by_name(self) -> None:
        assert MateType["LOCK"] is MateType.LOCK

    def test_no_alias_collision(self) -> None:
        # Distinct ordinals guarantee Enum does not silently alias members
        # that happen to share the same default_dof.
        ordinals = [m.value[0] for m in MateType]
        assert len(ordinals) == len(set(ordinals))
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# GeometryType
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestGeometryType:
    """GeometryType enum."""

    def test_member_count(self) -> None:
        """The enum exposes exactly five geometry kinds."""
        assert len(GeometryType) == 5

    def test_string_values(self) -> None:
        """Values are the lowercase member names (stable serialization form)."""
        for gt in GeometryType:
            assert isinstance(gt.value, str)
            assert gt.value == gt.name.lower()

    def test_access_by_name(self) -> None:
        """Name-based lookup returns the identical singleton member."""
        assert GeometryType["FACE"] is GeometryType.FACE
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# GeometryRef
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestGeometryRef:
    """GeometryRef dataclass."""

    def test_construction(self) -> None:
        """Fields are stored as given; the helper supplies a default direction for AXIS."""
        ref = _make_ref(0, GeometryType.AXIS, geometry_id="Axis001")
        assert ref.body_id == 0
        assert ref.geometry_type is GeometryType.AXIS
        assert ref.geometry_id == "Axis001"
        np.testing.assert_array_equal(ref.origin, np.zeros(3))
        assert ref.direction is not None

    def test_default_direction_none(self) -> None:
        """Constructing directly without a direction leaves it as None."""
        ref = GeometryRef(
            body_id=0,
            geometry_type=GeometryType.POINT,
            geometry_id="Point001",
        )
        assert ref.direction is None

    def test_to_dict_round_trip(self) -> None:
        """to_dict -> from_dict preserves every field (arrays up to float precision)."""
        ref = _make_ref(
            1,
            GeometryType.FACE,
            origin=np.array([1.0, 2.0, 3.0]),
            direction=np.array([0.0, 1.0, 0.0]),
        )
        d = ref.to_dict()
        restored = GeometryRef.from_dict(d)
        assert restored.body_id == ref.body_id
        assert restored.geometry_type is ref.geometry_type
        assert restored.geometry_id == ref.geometry_id
        np.testing.assert_array_almost_equal(restored.origin, ref.origin)
        assert restored.direction is not None
        np.testing.assert_array_almost_equal(restored.direction, ref.direction)

    def test_to_dict_with_none_direction(self) -> None:
        """A None direction survives the dict round trip as None."""
        ref = GeometryRef(
            body_id=2,
            geometry_type=GeometryType.POINT,
            geometry_id="Point002",
            origin=np.array([5.0, 6.0, 7.0]),
        )
        d = ref.to_dict()
        assert d["direction"] is None
        restored = GeometryRef.from_dict(d)
        assert restored.direction is None
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Mate
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMate:
    """Mate dataclass."""

    def test_construction(self) -> None:
        """Required fields are stored as given."""
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.FACE)
        m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
        assert m.mate_id == 0
        assert m.mate_type is MateType.COINCIDENT

    def test_value_default_zero(self) -> None:
        """The mate's scalar value (distance/angle) defaults to 0.0."""
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.FACE)
        m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
        assert m.value == 0.0

    def test_tolerance_default(self) -> None:
        """The geometric tolerance defaults to 1e-6."""
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.FACE)
        m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
        assert m.tolerance == 1e-6

    def test_to_dict_round_trip(self) -> None:
        """to_dict -> from_dict preserves ids, type, refs, value and tolerance."""
        ref_a = _make_ref(0, GeometryType.AXIS, origin=np.array([1.0, 0.0, 0.0]))
        ref_b = _make_ref(1, GeometryType.AXIS, origin=np.array([2.0, 0.0, 0.0]))
        m = Mate(
            mate_id=5,
            mate_type=MateType.CONCENTRIC,
            ref_a=ref_a,
            ref_b=ref_b,
            value=0.0,
            tolerance=1e-8,
        )
        d = m.to_dict()
        restored = Mate.from_dict(d)
        assert restored.mate_id == m.mate_id
        assert restored.mate_type is m.mate_type
        assert restored.ref_a.body_id == m.ref_a.body_id
        assert restored.ref_b.body_id == m.ref_b.body_id
        assert restored.value == m.value
        assert restored.tolerance == m.tolerance

    def test_from_dict_missing_optional(self) -> None:
        """from_dict falls back to defaults when optional keys are absent."""
        d = {
            "mate_id": 1,
            "mate_type": "DISTANCE",
            "ref_a": _make_ref(0, GeometryType.POINT).to_dict(),
            "ref_b": _make_ref(1, GeometryType.POINT).to_dict(),
        }
        m = Mate.from_dict(d)
        assert m.value == 0.0
        assert m.tolerance == 1e-6
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# dof_removed
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestDofRemoved:
    """Context-dependent DOF removal counts."""

    def test_coincident_face_face(self) -> None:
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.FACE)
        assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3

    def test_coincident_point_point(self) -> None:
        ref_a = _make_ref(0, GeometryType.POINT)
        ref_b = _make_ref(1, GeometryType.POINT)
        assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 3

    def test_coincident_edge_edge(self) -> None:
        # Edge-edge coincidence is weaker than face-face: only 2 DOF removed.
        ref_a = _make_ref(0, GeometryType.EDGE)
        ref_b = _make_ref(1, GeometryType.EDGE)
        assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 2

    def test_coincident_face_point(self) -> None:
        # A point constrained to lie on a face removes a single DOF.
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.POINT)
        assert dof_removed(MateType.COINCIDENT, ref_a, ref_b) == 1

    def test_concentric_axis_axis(self) -> None:
        ref_a = _make_ref(0, GeometryType.AXIS)
        ref_b = _make_ref(1, GeometryType.AXIS)
        assert dof_removed(MateType.CONCENTRIC, ref_a, ref_b) == 2

    def test_lock_any(self) -> None:
        # LOCK removes all 6 DOF regardless of geometry kinds.
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.POINT)
        assert dof_removed(MateType.LOCK, ref_a, ref_b) == 6

    def test_distance_any(self) -> None:
        # DISTANCE always removes exactly one DOF.
        ref_a = _make_ref(0, GeometryType.POINT)
        ref_b = _make_ref(1, GeometryType.EDGE)
        assert dof_removed(MateType.DISTANCE, ref_a, ref_b) == 1

    def test_unknown_combo_uses_default(self) -> None:
        """Unlisted geometry combos fall back to default_dof."""
        ref_a = _make_ref(0, GeometryType.EDGE)
        ref_b = _make_ref(1, GeometryType.POINT)
        result = dof_removed(MateType.COINCIDENT, ref_a, ref_b)
        assert result == MateType.COINCIDENT.default_dof
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Mate.validate
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
|
|
||||||
class TestMateValidation:
    """Mate.validate() compatibility checks."""

    def test_valid_concentric(self) -> None:
        """Axis-axis is a legal pairing for CONCENTRIC."""
        ref_a = _make_ref(0, GeometryType.AXIS)
        ref_b = _make_ref(1, GeometryType.AXIS)
        m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
        m.validate()  # should not raise

    def test_invalid_concentric_face(self) -> None:
        """FACE is not a valid geometry for a CONCENTRIC mate."""
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.AXIS)
        m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
        with pytest.raises(ValueError, match="CONCENTRIC"):
            m.validate()

    def test_valid_coincident_face_face(self) -> None:
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(1, GeometryType.FACE)
        m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
        m.validate()  # should not raise

    def test_invalid_self_mate(self) -> None:
        """Both refs on the same body is rejected even with distinct geometry ids."""
        ref_a = _make_ref(0, GeometryType.FACE)
        ref_b = _make_ref(0, GeometryType.FACE, geometry_id="Face002")
        m = Mate(mate_id=0, mate_type=MateType.COINCIDENT, ref_a=ref_a, ref_b=ref_b)
        with pytest.raises(ValueError, match="Self-mate"):
            m.validate()

    def test_invalid_parallel_point(self) -> None:
        """POINT carries no direction, so it cannot participate in PARALLEL."""
        ref_a = _make_ref(0, GeometryType.POINT)
        ref_b = _make_ref(1, GeometryType.AXIS)
        m = Mate(mate_id=0, mate_type=MateType.PARALLEL, ref_a=ref_a, ref_b=ref_b)
        with pytest.raises(ValueError, match="PARALLEL"):
            m.validate()

    def test_invalid_tangent_axis(self) -> None:
        ref_a = _make_ref(0, GeometryType.AXIS)
        ref_b = _make_ref(1, GeometryType.FACE)
        m = Mate(mate_id=0, mate_type=MateType.TANGENT, ref_a=ref_a, ref_b=ref_b)
        with pytest.raises(ValueError, match="TANGENT"):
            m.validate()

    def test_missing_direction_for_axis(self) -> None:
        """An AXIS ref without a direction vector fails validation."""
        ref_a = GeometryRef(
            body_id=0,
            geometry_type=GeometryType.AXIS,
            geometry_id="Axis001",
            origin=np.zeros(3),
            direction=None,  # missing!
        )
        ref_b = _make_ref(1, GeometryType.AXIS)
        m = Mate(mate_id=0, mate_type=MateType.CONCENTRIC, ref_a=ref_a, ref_b=ref_b)
        with pytest.raises(ValueError, match="direction"):
            m.validate()
188
tests/models/test_assembly_gnn.py
Normal file
188
tests/models/test_assembly_gnn.py
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
"""Tests for solver.models.assembly_gnn -- main model wiring."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from solver.models.assembly_gnn import AssemblyGNN
|
||||||
|
|
||||||
|
|
||||||
|
def _default_heads_config(dof_tracking: bool = False) -> dict:
|
||||||
|
return {
|
||||||
|
"edge_classification": {"enabled": True, "hidden_dim": 64},
|
||||||
|
"graph_classification": {"enabled": True, "num_classes": 4},
|
||||||
|
"joint_type": {"enabled": True, "num_classes": 11},
|
||||||
|
"dof_regression": {"enabled": True},
|
||||||
|
"dof_tracking": {"enabled": dof_tracking},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _random_graph(
|
||||||
|
n_nodes: int = 8,
|
||||||
|
n_edges: int = 16,
|
||||||
|
batch_size: int = 2,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
x = torch.randn(n_nodes, 22)
|
||||||
|
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
||||||
|
edge_attr = torch.randn(n_edges, 22)
|
||||||
|
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
||||||
|
if len(batch) < n_nodes:
|
||||||
|
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
||||||
|
return x, edge_index, edge_attr, batch
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyGNNGIN:
    """AssemblyGNN with GIN encoder."""

    def test_forward_all_heads(self) -> None:
        """Forward pass with every head enabled emits all four prediction keys."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 64, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert "edge_pred" in preds
        assert "graph_pred" in preds
        assert "joint_type_pred" in preds
        assert "dof_pred" in preds

    def test_output_shapes(self) -> None:
        """Per-edge heads are sized by edge count, per-graph heads by batch size."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 64, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20, batch_size=3)
        preds = model(x, ei, ea, batch)
        assert preds["edge_pred"].shape == (20, 1)
        assert preds["graph_pred"].shape == (3, 4)
        assert preds["joint_type_pred"].shape == (20, 11)
        assert preds["dof_pred"].shape == (3, 1)

    def test_gradients_flow(self) -> None:
        """Backprop from a sum over all heads reaches the input features."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        x, ei, ea, batch = _random_graph()
        x.requires_grad_(True)
        preds = model(x, ei, ea, batch)
        total = sum(p.sum() for p in preds.values())
        total.backward()
        assert x.grad is not None

    def test_no_heads_returns_empty(self) -> None:
        """An empty heads_config yields an empty prediction dict."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config={},
        )
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert len(preds) == 0
|
||||||
|
class TestAssemblyGNNGAT:
    """AssemblyGNN with GAT encoder."""

    def test_forward_all_heads(self) -> None:
        """With dof_tracking enabled, the per-body head appears alongside the rest."""
        model = AssemblyGNN(
            encoder_type="gat",
            encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
            heads_config=_default_heads_config(dof_tracking=True),
        )
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert "edge_pred" in preds
        assert "graph_pred" in preds
        assert "joint_type_pred" in preds
        assert "dof_pred" in preds
        assert "body_dof_pred" in preds

    def test_body_dof_shape(self) -> None:
        """body_dof_pred is per-node with two output channels."""
        model = AssemblyGNN(
            encoder_type="gat",
            encoder_config={"hidden_dim": 64, "num_layers": 2, "num_heads": 4},
            heads_config=_default_heads_config(dof_tracking=True),
        )
        x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=20)
        preds = model(x, ei, ea, batch)
        assert preds["body_dof_pred"].shape == (10, 2)
|
||||||
|
class TestAssemblyGNNEdgeCases:
    """Edge cases and error handling."""

    def test_unknown_encoder_raises(self) -> None:
        """An unsupported encoder_type fails fast with ValueError."""
        with pytest.raises(ValueError, match="Unknown encoder"):
            AssemblyGNN(encoder_type="transformer")

    def test_selective_heads(self) -> None:
        """Only enabled heads produce output."""
        config = {
            "edge_classification": {"enabled": True},
            "graph_classification": {"enabled": False},
            "joint_type": {"enabled": True, "num_classes": 11},
        }
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=config,
        )
        x, ei, ea, batch = _random_graph()
        preds = model(x, ei, ea, batch)
        assert "edge_pred" in preds
        assert "joint_type_pred" in preds
        assert "graph_pred" not in preds
        assert "dof_pred" not in preds

    def test_no_batch_single_graph(self) -> None:
        """Omitting the batch vector treats all nodes as one graph."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        x = torch.randn(6, 22)
        ei = torch.randint(0, 6, (2, 10))
        ea = torch.randn(10, 22)
        preds = model(x, ei, ea)
        assert preds["graph_pred"].shape == (1, 4)
        assert preds["dof_pred"].shape == (1, 1)

    def test_parameter_count_reasonable(self) -> None:
        """Sanity check that model has learnable parameters."""
        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 64, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        n_params = sum(p.numel() for p in model.parameters())
        assert n_params > 1000  # non-trivial model
|
||||||
|
class TestAssemblyGNNEndToEnd:
    """End-to-end test with datagen pipeline."""

    def test_datagen_to_model(self) -> None:
        """Synthetic assemblies convert to PyG graphs the model can consume."""
        # Imported locally so the rest of the suite runs even if the
        # datagen/graph-conversion extras are unavailable at collection time.
        from solver.datagen.generator import SyntheticAssemblyGenerator
        from solver.models.graph_conv import assembly_to_pyg

        gen = SyntheticAssemblyGenerator(seed=42)
        batch = gen.generate_training_batch(batch_size=2, complexity_tier="simple")

        model = AssemblyGNN(
            encoder_type="gin",
            encoder_config={"hidden_dim": 32, "num_layers": 2},
            heads_config=_default_heads_config(),
        )
        model.eval()

        for ex in batch:
            data = assembly_to_pyg(ex)
            with torch.no_grad():
                preds = model(data.x, data.edge_index, data.edge_attr)
            assert "edge_pred" in preds
            # One edge prediction per graph edge.
            assert preds["edge_pred"].shape[0] == data.edge_index.shape[1]
||||||
183
tests/models/test_encoder.py
Normal file
183
tests/models/test_encoder.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Tests for solver.models.encoder -- GIN and GAT graph encoders."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from solver.models.encoder import GATEncoder, GINEncoder
|
||||||
|
|
||||||
|
|
||||||
|
def _random_graph(
|
||||||
|
n_nodes: int = 8,
|
||||||
|
n_edges: int = 20,
|
||||||
|
node_dim: int = 22,
|
||||||
|
edge_dim: int = 22,
|
||||||
|
batch_size: int = 2,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
|
"""Create a random graph for testing."""
|
||||||
|
x = torch.randn(n_nodes, node_dim)
|
||||||
|
edge_index = torch.randint(0, n_nodes, (2, n_edges))
|
||||||
|
edge_attr = torch.randn(n_edges, edge_dim)
|
||||||
|
# Assign nodes to batches roughly evenly.
|
||||||
|
batch = torch.arange(batch_size).repeat_interleave(n_nodes // batch_size)
|
||||||
|
# Handle remainder nodes.
|
||||||
|
if len(batch) < n_nodes:
|
||||||
|
batch = torch.cat([batch, torch.full((n_nodes - len(batch),), batch_size - 1)])
|
||||||
|
return x, edge_index, edge_attr, batch
|
||||||
|
|
||||||
|
|
||||||
|
class TestGINEncoder:
|
||||||
|
"""GINEncoder shape and gradient tests."""
|
||||||
|
|
||||||
|
def test_output_shapes(self) -> None:
|
||||||
|
enc = GINEncoder(node_features_dim=22, edge_features_dim=22, hidden_dim=64, num_layers=2)
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape == (10, 64)
|
||||||
|
assert edge_emb.shape == (16, 64)
|
||||||
|
assert graph_emb.shape == (3, 64)
|
||||||
|
|
||||||
|
def test_hidden_dim_property(self) -> None:
|
||||||
|
enc = GINEncoder(hidden_dim=128)
|
||||||
|
assert enc.hidden_dim == 128
|
||||||
|
|
||||||
|
def test_default_dimensions(self) -> None:
|
||||||
|
enc = GINEncoder()
|
||||||
|
x, ei, ea, batch = _random_graph()
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape[1] == 128
|
||||||
|
assert edge_emb.shape[1] == 128
|
||||||
|
assert graph_emb.shape[1] == 128
|
||||||
|
|
||||||
|
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||||
|
enc = GINEncoder(hidden_dim=64, num_layers=2)
|
||||||
|
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||||
|
assert graph_emb.shape == (1, 64)
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||||
|
x.requires_grad_(True)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
loss = graph_emb.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.abs().sum() > 0
|
||||||
|
|
||||||
|
def test_zero_edges(self) -> None:
|
||||||
|
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||||
|
x = torch.randn(4, 22)
|
||||||
|
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||||
|
ea = torch.zeros(0, 22)
|
||||||
|
batch = torch.tensor([0, 0, 1, 1])
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape == (4, 32)
|
||||||
|
assert edge_emb.shape == (0, 32)
|
||||||
|
assert graph_emb.shape == (2, 32)
|
||||||
|
|
||||||
|
def test_single_node(self) -> None:
|
||||||
|
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||||
|
# Train with a small batch first to populate BN running stats.
|
||||||
|
x_train = torch.randn(4, 22)
|
||||||
|
ei_train = torch.zeros(2, 0, dtype=torch.long)
|
||||||
|
ea_train = torch.zeros(0, 22)
|
||||||
|
batch_train = torch.tensor([0, 0, 1, 1])
|
||||||
|
enc.train()
|
||||||
|
enc(x_train, ei_train, ea_train, batch_train)
|
||||||
|
# Now test single node in eval mode (BN uses running stats).
|
||||||
|
enc.eval()
|
||||||
|
x = torch.randn(1, 22)
|
||||||
|
ei = torch.zeros(2, 0, dtype=torch.long)
|
||||||
|
ea = torch.zeros(0, 22)
|
||||||
|
with torch.no_grad():
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea)
|
||||||
|
assert node_emb.shape == (1, 32)
|
||||||
|
assert graph_emb.shape == (1, 32)
|
||||||
|
|
||||||
|
def test_eval_mode(self) -> None:
|
||||||
|
"""Encoder works in eval mode (BatchNorm uses running stats)."""
|
||||||
|
enc = GINEncoder(hidden_dim=32, num_layers=2)
|
||||||
|
# Forward pass in train mode to populate BN stats.
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||||
|
enc.train()
|
||||||
|
enc(x, ei, ea, batch)
|
||||||
|
# Switch to eval.
|
||||||
|
enc.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape[1] == 32
|
||||||
|
|
||||||
|
|
||||||
|
class TestGATEncoder:
|
||||||
|
"""GATEncoder shape and gradient tests."""
|
||||||
|
|
||||||
|
def test_output_shapes(self) -> None:
|
||||||
|
enc = GATEncoder(
|
||||||
|
node_features_dim=22,
|
||||||
|
edge_features_dim=22,
|
||||||
|
hidden_dim=64,
|
||||||
|
num_layers=2,
|
||||||
|
num_heads=4,
|
||||||
|
)
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=10, n_edges=16, batch_size=3)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape == (10, 64)
|
||||||
|
assert edge_emb.shape == (16, 64)
|
||||||
|
assert graph_emb.shape == (3, 64)
|
||||||
|
|
||||||
|
def test_hidden_dim_property(self) -> None:
|
||||||
|
enc = GATEncoder(hidden_dim=256, num_heads=8)
|
||||||
|
assert enc.hidden_dim == 256
|
||||||
|
|
||||||
|
def test_default_dimensions(self) -> None:
|
||||||
|
enc = GATEncoder()
|
||||||
|
x, ei, ea, batch = _random_graph()
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape[1] == 256
|
||||||
|
assert edge_emb.shape[1] == 256
|
||||||
|
assert graph_emb.shape[1] == 256
|
||||||
|
|
||||||
|
def test_no_batch_defaults_to_single_graph(self) -> None:
|
||||||
|
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||||
|
x, ei, ea, _ = _random_graph(n_nodes=6, n_edges=10)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch=None)
|
||||||
|
assert graph_emb.shape == (1, 64)
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||||
|
x.requires_grad_(True)
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
loss = graph_emb.sum()
|
||||||
|
loss.backward()
|
||||||
|
assert x.grad is not None
|
||||||
|
assert x.grad.abs().sum() > 0
|
||||||
|
|
||||||
|
def test_residual_connection(self) -> None:
|
||||||
|
"""With residual=True, output should differ from residual=False."""
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
enc_res = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=True)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
enc_no = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4, residual=False)
|
||||||
|
with torch.no_grad():
|
||||||
|
n1, _, _ = enc_res(x, ei, ea, batch)
|
||||||
|
n2, _, _ = enc_no(x, ei, ea, batch)
|
||||||
|
# Outputs should generally differ (unless by very unlikely coincidence).
|
||||||
|
assert not torch.allclose(n1, n2, atol=1e-4)
|
||||||
|
|
||||||
|
def test_hidden_dim_must_divide_heads(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="divisible"):
|
||||||
|
GATEncoder(hidden_dim=100, num_heads=8)
|
||||||
|
|
||||||
|
def test_eval_mode(self) -> None:
|
||||||
|
enc = GATEncoder(hidden_dim=64, num_layers=2, num_heads=4)
|
||||||
|
x, ei, ea, batch = _random_graph(n_nodes=8, n_edges=12)
|
||||||
|
enc.train()
|
||||||
|
enc(x, ei, ea, batch)
|
||||||
|
enc.eval()
|
||||||
|
with torch.no_grad():
|
||||||
|
node_emb, edge_emb, graph_emb = enc(x, ei, ea, batch)
|
||||||
|
assert node_emb.shape[1] == 64
|
||||||
110
tests/models/test_factory.py
Normal file
110
tests/models/test_factory.py
Normal file
@@ -0,0 +1,110 @@
|
|||||||
|
"""Tests for solver.models.factory -- model and loss construction from config."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from solver.models.assembly_gnn import AssemblyGNN
|
||||||
|
from solver.models.factory import build_loss, build_model
|
||||||
|
from solver.models.losses import MultiTaskLoss
|
||||||
|
|
||||||
|
|
||||||
|
def _load_yaml(path: str) -> dict:
|
||||||
|
with open(path) as f:
|
||||||
|
return yaml.safe_load(f)
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildModel:
|
||||||
|
"""build_model constructs AssemblyGNN from config."""
|
||||||
|
|
||||||
|
def test_baseline_config(self) -> None:
|
||||||
|
config = _load_yaml("configs/model/baseline.yaml")
|
||||||
|
model = build_model(config)
|
||||||
|
assert isinstance(model, AssemblyGNN)
|
||||||
|
assert model.encoder.hidden_dim == 128
|
||||||
|
|
||||||
|
def test_gat_config(self) -> None:
|
||||||
|
config = _load_yaml("configs/model/gat.yaml")
|
||||||
|
model = build_model(config)
|
||||||
|
assert isinstance(model, AssemblyGNN)
|
||||||
|
assert model.encoder.hidden_dim == 256
|
||||||
|
|
||||||
|
def test_baseline_heads_present(self) -> None:
|
||||||
|
config = _load_yaml("configs/model/baseline.yaml")
|
||||||
|
model = build_model(config)
|
||||||
|
assert "edge_pred" in model.heads
|
||||||
|
assert "graph_pred" in model.heads
|
||||||
|
assert "joint_type_pred" in model.heads
|
||||||
|
assert "dof_pred" in model.heads
|
||||||
|
|
||||||
|
def test_gat_has_dof_tracking(self) -> None:
|
||||||
|
config = _load_yaml("configs/model/gat.yaml")
|
||||||
|
model = build_model(config)
|
||||||
|
assert "body_dof_pred" in model.heads
|
||||||
|
|
||||||
|
def test_baseline_no_dof_tracking(self) -> None:
|
||||||
|
config = _load_yaml("configs/model/baseline.yaml")
|
||||||
|
model = build_model(config)
|
||||||
|
assert "body_dof_pred" not in model.heads
|
||||||
|
|
||||||
|
def test_minimal_config(self) -> None:
|
||||||
|
config = {"architecture": "gin"}
|
||||||
|
model = build_model(config)
|
||||||
|
assert isinstance(model, AssemblyGNN)
|
||||||
|
# No heads enabled.
|
||||||
|
assert len(model.heads) == 0
|
||||||
|
|
||||||
|
def test_custom_config(self) -> None:
|
||||||
|
config = {
|
||||||
|
"architecture": "gin",
|
||||||
|
"encoder": {"hidden_dim": 64, "num_layers": 2},
|
||||||
|
"heads": {
|
||||||
|
"edge_classification": {"enabled": True},
|
||||||
|
"graph_classification": {"enabled": True, "num_classes": 4},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
model = build_model(config)
|
||||||
|
assert model.encoder.hidden_dim == 64
|
||||||
|
assert "edge_pred" in model.heads
|
||||||
|
assert "graph_pred" in model.heads
|
||||||
|
assert "joint_type_pred" not in model.heads
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildLoss:
|
||||||
|
"""build_loss constructs MultiTaskLoss from training config."""
|
||||||
|
|
||||||
|
def test_pretrain_config(self) -> None:
|
||||||
|
config = _load_yaml("configs/training/pretrain.yaml")
|
||||||
|
loss_fn = build_loss(config)
|
||||||
|
assert isinstance(loss_fn, MultiTaskLoss)
|
||||||
|
|
||||||
|
def test_weights_from_config(self) -> None:
|
||||||
|
config = _load_yaml("configs/training/pretrain.yaml")
|
||||||
|
loss_fn = build_loss(config)
|
||||||
|
assert loss_fn.weights["edge"] == 1.0
|
||||||
|
assert loss_fn.weights["graph"] == 0.5
|
||||||
|
assert loss_fn.weights["joint_type"] == 0.3
|
||||||
|
assert loss_fn.weights["dof"] == 0.2
|
||||||
|
|
||||||
|
def test_redundant_penalty_from_config(self) -> None:
|
||||||
|
config = _load_yaml("configs/training/pretrain.yaml")
|
||||||
|
loss_fn = build_loss(config)
|
||||||
|
assert loss_fn.redundant_penalty == 2.0
|
||||||
|
|
||||||
|
def test_empty_config_uses_defaults(self) -> None:
|
||||||
|
loss_fn = build_loss({})
|
||||||
|
assert isinstance(loss_fn, MultiTaskLoss)
|
||||||
|
assert loss_fn.weights["edge"] == 1.0
|
||||||
|
|
||||||
|
def test_custom_weights(self) -> None:
|
||||||
|
config = {
|
||||||
|
"loss": {
|
||||||
|
"edge_weight": 2.0,
|
||||||
|
"graph_weight": 1.0,
|
||||||
|
"redundant_penalty": 5.0,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
loss_fn = build_loss(config)
|
||||||
|
assert loss_fn.weights["edge"] == 2.0
|
||||||
|
assert loss_fn.weights["graph"] == 1.0
|
||||||
|
assert loss_fn.redundant_penalty == 5.0
|
||||||
144
tests/models/test_graph_conv.py
Normal file
144
tests/models/test_graph_conv.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Tests for solver.models.graph_conv -- assembly to PyG conversion."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from solver.datagen.generator import SyntheticAssemblyGenerator
|
||||||
|
from solver.datagen.types import JointType
|
||||||
|
from solver.models.graph_conv import ASSEMBLY_CLASSES, JOINT_TYPE_NAMES, assembly_to_pyg
|
||||||
|
|
||||||
|
|
||||||
|
def _make_example(n_bodies: int = 4, grounded: bool = True, seed: int = 0) -> dict:
|
||||||
|
"""Generate a single training example via the datagen pipeline."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=seed)
|
||||||
|
batch = gen.generate_training_batch(batch_size=1, n_bodies_range=(n_bodies, n_bodies + 1))
|
||||||
|
return batch[0]
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyToPyg:
|
||||||
|
"""assembly_to_pyg converts datagen dicts to PyG Data correctly."""
|
||||||
|
|
||||||
|
def test_node_feature_shape(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=5)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
assert data.x.shape == (5, 22)
|
||||||
|
|
||||||
|
def test_edge_feature_shape(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
n_joints = ex["n_joints"]
|
||||||
|
assert data.edge_attr.shape == (n_joints * 2, 22)
|
||||||
|
|
||||||
|
def test_edge_index_bidirectional(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
ei = data.edge_index
|
||||||
|
# Each joint produces 2 directed edges: a->b and b->a.
|
||||||
|
for i in range(0, ei.size(1), 2):
|
||||||
|
assert ei[0, i].item() == ei[1, i + 1].item()
|
||||||
|
assert ei[1, i].item() == ei[0, i + 1].item()
|
||||||
|
|
||||||
|
def test_edge_index_shape(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
n_joints = ex["n_joints"]
|
||||||
|
assert data.edge_index.shape == (2, n_joints * 2)
|
||||||
|
|
||||||
|
def test_node_features_centered(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=5)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
# Positions (dims 0-2) should be centered (mean ~0).
|
||||||
|
pos = data.x[:, :3]
|
||||||
|
assert pos.mean(dim=0).abs().max().item() < 1e-5
|
||||||
|
|
||||||
|
def test_grounded_flag_set(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4, grounded=True)
|
||||||
|
ex["grounded"] = True
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
assert data.x[0, 12].item() == 1.0
|
||||||
|
|
||||||
|
def test_ungrounded_flag_clear(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
ex["grounded"] = False
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
assert (data.x[:, 12] == 0.0).all()
|
||||||
|
|
||||||
|
def test_edge_type_one_hot_valid(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=5)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
if data.edge_attr.size(0) > 0:
|
||||||
|
onehot = data.edge_attr[:, :11]
|
||||||
|
# Each row should have exactly one 1.0.
|
||||||
|
assert (onehot.sum(dim=1) == 1.0).all()
|
||||||
|
|
||||||
|
def test_labels_present_when_requested(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex, include_labels=True)
|
||||||
|
assert hasattr(data, "y_edge")
|
||||||
|
assert hasattr(data, "y_graph")
|
||||||
|
assert hasattr(data, "y_joint_type")
|
||||||
|
assert hasattr(data, "y_dof")
|
||||||
|
assert hasattr(data, "y_body_dof")
|
||||||
|
|
||||||
|
def test_labels_absent_when_not_requested(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex, include_labels=False)
|
||||||
|
assert not hasattr(data, "y_edge")
|
||||||
|
assert not hasattr(data, "y_graph")
|
||||||
|
|
||||||
|
def test_graph_classification_label_mapping(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
cls = ex["assembly_classification"]
|
||||||
|
expected = ASSEMBLY_CLASSES[cls]
|
||||||
|
assert data.y_graph.item() == expected
|
||||||
|
|
||||||
|
def test_body_dof_shape(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=5)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
assert data.y_body_dof.shape == (5, 2)
|
||||||
|
|
||||||
|
def test_edge_labels_binary(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=5)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
if data.y_edge.numel() > 0:
|
||||||
|
assert ((data.y_edge == 0.0) | (data.y_edge == 1.0)).all()
|
||||||
|
|
||||||
|
def test_dof_removed_normalized(self) -> None:
|
||||||
|
ex = _make_example(n_bodies=4)
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
if data.edge_attr.size(0) > 0:
|
||||||
|
dof_norm = data.edge_attr[:, 21]
|
||||||
|
assert (dof_norm >= 0.0).all()
|
||||||
|
assert (dof_norm <= 1.0).all()
|
||||||
|
|
||||||
|
def test_roundtrip_with_generator(self) -> None:
|
||||||
|
"""Generate a real example and convert -- no crash."""
|
||||||
|
gen = SyntheticAssemblyGenerator(seed=42)
|
||||||
|
batch = gen.generate_training_batch(batch_size=5, complexity_tier="simple")
|
||||||
|
for ex in batch:
|
||||||
|
data = assembly_to_pyg(ex)
|
||||||
|
assert data.x.shape[1] == 22
|
||||||
|
assert data.edge_attr.shape[1] == 22
|
||||||
|
|
||||||
|
|
||||||
|
class TestAssemblyClasses:
|
||||||
|
"""ASSEMBLY_CLASSES covers all classifications."""
|
||||||
|
|
||||||
|
def test_four_classes(self) -> None:
|
||||||
|
assert len(ASSEMBLY_CLASSES) == 4
|
||||||
|
|
||||||
|
def test_values_are_0_to_3(self) -> None:
|
||||||
|
assert set(ASSEMBLY_CLASSES.values()) == {0, 1, 2, 3}
|
||||||
|
|
||||||
|
|
||||||
|
class TestJointTypeNames:
|
||||||
|
"""JOINT_TYPE_NAMES matches the JointType enum."""
|
||||||
|
|
||||||
|
def test_length_matches_enum(self) -> None:
|
||||||
|
assert len(JOINT_TYPE_NAMES) == len(JointType)
|
||||||
|
|
||||||
|
def test_order_matches_ordinal(self) -> None:
|
||||||
|
for i, name in enumerate(JOINT_TYPE_NAMES):
|
||||||
|
assert JointType[name].value[0] == i
|
||||||
182
tests/models/test_heads.py
Normal file
182
tests/models/test_heads.py
Normal file
@@ -0,0 +1,182 @@
|
|||||||
|
"""Tests for solver.models.heads -- task-specific prediction heads."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from solver.models.heads import (
|
||||||
|
DOFRegressionHead,
|
||||||
|
DOFTrackingHead,
|
||||||
|
EdgeClassificationHead,
|
||||||
|
GraphClassificationHead,
|
||||||
|
JointTypeHead,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEdgeClassificationHead:
|
||||||
|
"""EdgeClassificationHead produces correct shape and gradients."""
|
||||||
|
|
||||||
|
def test_output_shape(self) -> None:
|
||||||
|
head = EdgeClassificationHead(hidden_dim=128)
|
||||||
|
edge_emb = torch.randn(20, 128)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (20, 1)
|
||||||
|
|
||||||
|
def test_output_shape_small(self) -> None:
|
||||||
|
head = EdgeClassificationHead(hidden_dim=32)
|
||||||
|
edge_emb = torch.randn(5, 32)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (5, 1)
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
head = EdgeClassificationHead(hidden_dim=64)
|
||||||
|
edge_emb = torch.randn(10, 64, requires_grad=True)
|
||||||
|
out = head(edge_emb)
|
||||||
|
out.sum().backward()
|
||||||
|
assert edge_emb.grad is not None
|
||||||
|
assert edge_emb.grad.abs().sum() > 0
|
||||||
|
|
||||||
|
def test_zero_edges(self) -> None:
|
||||||
|
head = EdgeClassificationHead(hidden_dim=64)
|
||||||
|
edge_emb = torch.zeros(0, 64)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (0, 1)
|
||||||
|
|
||||||
|
def test_output_is_logits(self) -> None:
|
||||||
|
"""Output should be unbounded logits (not probabilities)."""
|
||||||
|
head = EdgeClassificationHead(hidden_dim=64)
|
||||||
|
torch.manual_seed(42)
|
||||||
|
edge_emb = torch.randn(100, 64)
|
||||||
|
out = head(edge_emb)
|
||||||
|
# Logits can be negative.
|
||||||
|
assert out.min().item() < 0 or out.max().item() > 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestGraphClassificationHead:
|
||||||
|
"""GraphClassificationHead produces correct shape and gradients."""
|
||||||
|
|
||||||
|
def test_output_shape(self) -> None:
|
||||||
|
head = GraphClassificationHead(hidden_dim=128, num_classes=4)
|
||||||
|
graph_emb = torch.randn(3, 128)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert out.shape == (3, 4)
|
||||||
|
|
||||||
|
def test_custom_num_classes(self) -> None:
|
||||||
|
head = GraphClassificationHead(hidden_dim=64, num_classes=8)
|
||||||
|
graph_emb = torch.randn(2, 64)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert out.shape == (2, 8)
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
head = GraphClassificationHead(hidden_dim=64)
|
||||||
|
graph_emb = torch.randn(2, 64, requires_grad=True)
|
||||||
|
out = head(graph_emb)
|
||||||
|
out.sum().backward()
|
||||||
|
assert graph_emb.grad is not None
|
||||||
|
|
||||||
|
def test_single_graph(self) -> None:
|
||||||
|
head = GraphClassificationHead(hidden_dim=128)
|
||||||
|
graph_emb = torch.randn(1, 128)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert out.shape == (1, 4)
|
||||||
|
|
||||||
|
|
||||||
|
class TestJointTypeHead:
|
||||||
|
"""JointTypeHead produces correct shape and gradients."""
|
||||||
|
|
||||||
|
def test_output_shape(self) -> None:
|
||||||
|
head = JointTypeHead(hidden_dim=128, num_classes=11)
|
||||||
|
edge_emb = torch.randn(20, 128)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (20, 11)
|
||||||
|
|
||||||
|
def test_custom_classes(self) -> None:
|
||||||
|
head = JointTypeHead(hidden_dim=64, num_classes=7)
|
||||||
|
edge_emb = torch.randn(10, 64)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (10, 7)
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
head = JointTypeHead(hidden_dim=64)
|
||||||
|
edge_emb = torch.randn(10, 64, requires_grad=True)
|
||||||
|
out = head(edge_emb)
|
||||||
|
out.sum().backward()
|
||||||
|
assert edge_emb.grad is not None
|
||||||
|
|
||||||
|
def test_zero_edges(self) -> None:
|
||||||
|
head = JointTypeHead(hidden_dim=64)
|
||||||
|
edge_emb = torch.zeros(0, 64)
|
||||||
|
out = head(edge_emb)
|
||||||
|
assert out.shape == (0, 11)
|
||||||
|
|
||||||
|
|
||||||
|
class TestDOFRegressionHead:
|
||||||
|
"""DOFRegressionHead produces correct shape and non-negative output."""
|
||||||
|
|
||||||
|
def test_output_shape(self) -> None:
|
||||||
|
head = DOFRegressionHead(hidden_dim=128)
|
||||||
|
graph_emb = torch.randn(3, 128)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert out.shape == (3, 1)
|
||||||
|
|
||||||
|
def test_output_non_negative(self) -> None:
|
||||||
|
"""Softplus ensures non-negative output."""
|
||||||
|
head = DOFRegressionHead(hidden_dim=64)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
graph_emb = torch.randn(50, 64)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert (out >= 0).all()
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
head = DOFRegressionHead(hidden_dim=64)
|
||||||
|
graph_emb = torch.randn(2, 64, requires_grad=True)
|
||||||
|
out = head(graph_emb)
|
||||||
|
out.sum().backward()
|
||||||
|
assert graph_emb.grad is not None
|
||||||
|
|
||||||
|
def test_single_graph(self) -> None:
|
||||||
|
head = DOFRegressionHead(hidden_dim=32)
|
||||||
|
graph_emb = torch.randn(1, 32)
|
||||||
|
out = head(graph_emb)
|
||||||
|
assert out.shape == (1, 1)
|
||||||
|
assert out.item() >= 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDOFTrackingHead:
|
||||||
|
"""DOFTrackingHead produces correct shape and non-negative output."""
|
||||||
|
|
||||||
|
def test_output_shape(self) -> None:
|
||||||
|
head = DOFTrackingHead(hidden_dim=128)
|
||||||
|
node_emb = torch.randn(10, 128)
|
||||||
|
out = head(node_emb)
|
||||||
|
assert out.shape == (10, 2)
|
||||||
|
|
||||||
|
def test_output_non_negative(self) -> None:
|
||||||
|
"""Softplus ensures non-negative output."""
|
||||||
|
head = DOFTrackingHead(hidden_dim=64)
|
||||||
|
torch.manual_seed(0)
|
||||||
|
node_emb = torch.randn(50, 64)
|
||||||
|
out = head(node_emb)
|
||||||
|
assert (out >= 0).all()
|
||||||
|
|
||||||
|
def test_gradients_flow(self) -> None:
|
||||||
|
head = DOFTrackingHead(hidden_dim=64)
|
||||||
|
node_emb = torch.randn(10, 64, requires_grad=True)
|
||||||
|
out = head(node_emb)
|
||||||
|
out.sum().backward()
|
||||||
|
assert node_emb.grad is not None
|
||||||
|
|
||||||
|
def test_single_node(self) -> None:
|
||||||
|
head = DOFTrackingHead(hidden_dim=32)
|
||||||
|
node_emb = torch.randn(1, 32)
|
||||||
|
out = head(node_emb)
|
||||||
|
assert out.shape == (1, 2)
|
||||||
|
assert (out >= 0).all()
|
||||||
|
|
||||||
|
def test_two_columns_independent(self) -> None:
|
||||||
|
"""Translational and rotational DOF are independently predicted."""
|
||||||
|
head = DOFTrackingHead(hidden_dim=64)
|
||||||
|
node_emb = torch.randn(20, 64)
|
||||||
|
out = head(node_emb)
|
||||||
|
# The two columns should generally differ.
|
||||||
|
assert not torch.allclose(out[:, 0], out[:, 1], atol=1e-6)
|
||||||
156
tests/models/test_losses.py
Normal file
156
tests/models/test_losses.py
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
"""Tests for solver.models.losses -- uncertainty-weighted multi-task loss."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from solver.models.losses import MultiTaskLoss
|
||||||
|
|
||||||
|
|
||||||
|
def _make_predictions_and_targets(
|
||||||
|
n_edges: int = 20,
|
||||||
|
batch_size: int = 3,
|
||||||
|
n_nodes: int = 10,
|
||||||
|
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
|
||||||
|
preds = {
|
||||||
|
"edge_pred": torch.randn(n_edges, 1),
|
||||||
|
"graph_pred": torch.randn(batch_size, 4),
|
||||||
|
"joint_type_pred": torch.randn(n_edges, 11),
|
||||||
|
"dof_pred": torch.rand(batch_size, 1) * 10,
|
||||||
|
"body_dof_pred": torch.rand(n_nodes, 2) * 6,
|
||||||
|
}
|
||||||
|
targets = {
|
||||||
|
"y_edge": torch.randint(0, 2, (n_edges,)).float(),
|
||||||
|
"y_graph": torch.randint(0, 4, (batch_size,)),
|
||||||
|
"y_joint_type": torch.randint(0, 11, (n_edges,)),
|
||||||
|
"y_dof": torch.rand(batch_size, 1) * 10,
|
||||||
|
"y_body_dof": torch.rand(n_nodes, 2) * 6,
|
||||||
|
}
|
||||||
|
return preds, targets
|
||||||
|
|
||||||
|
|
||||||
|
class TestMultiTaskLoss:
|
||||||
|
"""MultiTaskLoss computation tests."""
|
||||||
|
|
||||||
|
def test_returns_scalar_and_breakdown(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds, targets = _make_predictions_and_targets()
|
||||||
|
total, breakdown = loss_fn(preds, targets)
|
||||||
|
assert total.dim() == 0 # scalar
|
||||||
|
assert isinstance(breakdown, dict)
|
||||||
|
|
||||||
|
def test_all_tasks_in_breakdown(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds, targets = _make_predictions_and_targets()
|
||||||
|
_, breakdown = loss_fn(preds, targets)
|
||||||
|
assert "edge" in breakdown
|
||||||
|
assert "graph" in breakdown
|
||||||
|
assert "joint_type" in breakdown
|
||||||
|
assert "dof" in breakdown
|
||||||
|
assert "body_dof" in breakdown
|
||||||
|
|
||||||
|
def test_total_is_positive(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds, targets = _make_predictions_and_targets()
|
||||||
|
total, _ = loss_fn(preds, targets)
|
||||||
|
# With random predictions, loss should be positive.
|
||||||
|
assert total.item() > 0
|
||||||
|
|
||||||
|
def test_skips_missing_predictions(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds = {"edge_pred": torch.randn(10, 1)}
|
||||||
|
targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
|
||||||
|
total, breakdown = loss_fn(preds, targets)
|
||||||
|
assert "edge" in breakdown
|
||||||
|
assert "graph" not in breakdown
|
||||||
|
assert "joint_type" not in breakdown
|
||||||
|
|
||||||
|
def test_skips_missing_targets(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds = {
|
||||||
|
"edge_pred": torch.randn(10, 1),
|
||||||
|
"graph_pred": torch.randn(2, 4),
|
||||||
|
}
|
||||||
|
targets = {"y_edge": torch.randint(0, 2, (10,)).float()}
|
||||||
|
_, breakdown = loss_fn(preds, targets)
|
||||||
|
assert "edge" in breakdown
|
||||||
|
assert "graph" not in breakdown
|
||||||
|
|
||||||
|
def test_gradients_flow_to_log_vars(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds, targets = _make_predictions_and_targets()
|
||||||
|
# Make preds require grad.
|
||||||
|
for k in preds:
|
||||||
|
preds[k] = preds[k].requires_grad_(True)
|
||||||
|
total, _ = loss_fn(preds, targets)
|
||||||
|
total.backward()
|
||||||
|
for name, param in loss_fn.log_vars.items():
|
||||||
|
assert param.grad is not None, f"No gradient for log_var[{name}]"
|
||||||
|
|
||||||
|
def test_gradients_flow_to_predictions(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
preds, targets = _make_predictions_and_targets()
|
||||||
|
for k in preds:
|
||||||
|
preds[k] = preds[k].requires_grad_(True)
|
||||||
|
total, _ = loss_fn(preds, targets)
|
||||||
|
total.backward()
|
||||||
|
for k, v in preds.items():
|
||||||
|
assert v.grad is not None, f"No gradient for prediction[{k}]"
|
||||||
|
|
||||||
|
def test_redundant_penalty_applies(self) -> None:
|
||||||
|
"""Redundant edges (label=0) should have higher loss contribution."""
|
||||||
|
loss_fn = MultiTaskLoss(redundant_penalty=5.0)
|
||||||
|
# All-zero predictions, label=0 (redundant).
|
||||||
|
preds_red = {"edge_pred": torch.zeros(10, 1)}
|
||||||
|
targets_red = {"y_edge": torch.zeros(10)}
|
||||||
|
total_red, _ = loss_fn(preds_red, targets_red)
|
||||||
|
|
||||||
|
loss_fn2 = MultiTaskLoss(redundant_penalty=1.0)
|
||||||
|
total_eq, _ = loss_fn2(preds_red, targets_red)
|
||||||
|
|
||||||
|
# Higher penalty should produce higher loss.
|
||||||
|
assert total_red.item() > total_eq.item()
|
||||||
|
|
||||||
|
def test_empty_predictions_returns_zero(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
total, breakdown = loss_fn({}, {})
|
||||||
|
assert total.item() == 0.0
|
||||||
|
assert len(breakdown) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestUncertaintyWeighting:
|
||||||
|
"""Test uncertainty weighting mechanism specifically."""
|
||||||
|
|
||||||
|
def test_log_vars_initialized_to_zero(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
for param in loss_fn.log_vars.values():
|
||||||
|
assert param.item() == 0.0
|
||||||
|
|
||||||
|
def test_log_vars_are_learnable(self) -> None:
|
||||||
|
loss_fn = MultiTaskLoss()
|
||||||
|
params = list(loss_fn.parameters())
|
||||||
|
log_var_params = [p for p in params if p.shape == (1,)]
|
||||||
|
assert len(log_var_params) == 5 # one per task
|
||||||
|
|
||||||
|
def test_weighting_reduces_high_loss_influence(self) -> None:
|
||||||
|
"""After a few gradient steps, log_var for a noisy task should increase."""
|
||||||
|
loss_fn = MultiTaskLoss(edge_weight=1.0, graph_weight=1.0)
|
||||||
|
optimizer = torch.optim.SGD(loss_fn.parameters(), lr=0.1)
|
||||||
|
|
||||||
|
# Simulate: edge task has high loss, graph has low.
|
||||||
|
for _ in range(20):
|
||||||
|
preds = {
|
||||||
|
"edge_pred": torch.randn(10, 1) * 10, # high variance -> high loss
|
||||||
|
"graph_pred": torch.zeros(2, 4), # near-zero loss
|
||||||
|
}
|
||||||
|
targets = {
|
||||||
|
"y_edge": torch.randint(0, 2, (10,)).float(),
|
||||||
|
"y_graph": torch.zeros(2, dtype=torch.long),
|
||||||
|
}
|
||||||
|
optimizer.zero_grad()
|
||||||
|
total, _ = loss_fn(preds, targets)
|
||||||
|
total.backward()
|
||||||
|
optimizer.step()
|
||||||
|
|
||||||
|
# The edge task log_var should have increased (higher uncertainty).
|
||||||
|
assert loss_fn.log_vars["edge"].item() > loss_fn.log_vars["graph"].item()
|
||||||
Reference in New Issue
Block a user