feat: ground truth labeling pipeline
- Create solver/datagen/labeling.py with label_assembly() function - Add dataclasses: ConstraintLabel, JointLabel, BodyDofLabel, AssemblyLabel, AssemblyLabels - Per-constraint labels: pebble_independent + jacobian_independent - Per-joint labels: aggregated independent/redundant/total counts - Per-body DOF: translational + rotational from nullspace projection - Assembly label: classification, total_dof, has_degeneracy flag - AssemblyLabels.to_dict() for JSON-serializable output - Integrate into generate_training_batch (adds 'labels' field) - Export AssemblyLabels and label_assembly from datagen package - Add 25 labeling tests + 1 batch structure test (184 total) Closes #9
This commit is contained in:
@@ -513,6 +513,7 @@ class TestTrainingBatch:
|
||||
"body_orientations",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"labels",
|
||||
"assembly_classification",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
@@ -586,6 +587,17 @@ class TestTrainingBatch:
|
||||
assert len(o) == 3
|
||||
assert len(o[0]) == 3
|
||||
|
||||
def test_labels_structure(self) -> None:
|
||||
"""Each example has labels dict with expected sub-keys."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
labels = ex["labels"]
|
||||
assert "per_constraint" in labels
|
||||
assert "per_joint" in labels
|
||||
assert "per_body" in labels
|
||||
assert "assembly" in labels
|
||||
|
||||
def test_grounded_field_present(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
|
||||
346
tests/datagen/test_labeling.py
Normal file
346
tests/datagen/test_labeling.py
Normal file
@@ -0,0 +1,346 @@
|
||||
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
|
||||
from solver.datagen.labeling import (
|
||||
label_assembly,
|
||||
)
|
||||
from solver.datagen.types import Joint, JointType, RigidBody
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
|
||||
return [RigidBody(body_id=i, position=np.array(pos)) for i, pos in enumerate(positions)]
|
||||
|
||||
|
||||
def _make_joint(
|
||||
jid: int,
|
||||
a: int,
|
||||
b: int,
|
||||
jtype: JointType,
|
||||
axis: tuple[float, ...] = (0.0, 0.0, 1.0),
|
||||
) -> Joint:
|
||||
return Joint(
|
||||
joint_id=jid,
|
||||
body_a=a,
|
||||
body_b=b,
|
||||
joint_type=jtype,
|
||||
anchor_a=np.zeros(3),
|
||||
anchor_b=np.zeros(3),
|
||||
axis=np.array(axis),
|
||||
)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-constraint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstraintLabels:
|
||||
"""Per-constraint labels combine pebble game and Jacobian results."""
|
||||
|
||||
def test_fixed_joint_all_independent(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 6
|
||||
for cl in labels.per_constraint:
|
||||
assert cl.pebble_independent is True
|
||||
assert cl.jacobian_independent is True
|
||||
|
||||
def test_revolute_joint_all_independent(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 5
|
||||
for cl in labels.per_constraint:
|
||||
assert cl.pebble_independent is True
|
||||
assert cl.jacobian_independent is True
|
||||
|
||||
def test_chain_constraint_count(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_constraint) == 10 # 5 + 5
|
||||
|
||||
def test_constraint_joint_ids(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.BALL),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0]
|
||||
j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1]
|
||||
assert len(j0_constraints) == 5 # revolute
|
||||
assert len(j1_constraints) == 3 # ball
|
||||
|
||||
def test_overconstrained_has_pebble_redundant(self) -> None:
|
||||
"""Triangle with revolute joints: some constraints redundant."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent)
|
||||
assert pebble_redundant > 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-joint labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestJointLabels:
|
||||
"""Per-joint aggregated labels."""
|
||||
|
||||
def test_fixed_joint_counts(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_joint) == 1
|
||||
jl = labels.per_joint[0]
|
||||
assert jl.joint_id == 0
|
||||
assert jl.independent_count == 6
|
||||
assert jl.redundant_count == 0
|
||||
assert jl.total == 6
|
||||
|
||||
def test_overconstrained_has_redundant_joints(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
total_redundant = sum(jl.redundant_count for jl in labels.per_joint)
|
||||
assert total_redundant > 0
|
||||
|
||||
def test_joint_total_equals_dof(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.BALL)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
jl = labels.per_joint[0]
|
||||
assert jl.total == 3 # ball has 3 DOF
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Per-body DOF labels
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBodyDofLabels:
|
||||
"""Per-body DOF signatures from nullspace projection."""
|
||||
|
||||
def test_fixed_joint_grounded_both_zero(self) -> None:
|
||||
"""Two bodies + fixed joint + grounded: both fully constrained."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
for bl in labels.per_body:
|
||||
assert bl.translational_dof == 0
|
||||
assert bl.rotational_dof == 0
|
||||
|
||||
def test_revolute_has_rotational_dof(self) -> None:
|
||||
"""Two bodies + revolute + grounded: body 1 has rotational DOF."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
b1 = next(b for b in labels.per_body if b.body_id == 1)
|
||||
# Revolute allows 1 rotation DOF
|
||||
assert b1.rotational_dof >= 1
|
||||
|
||||
def test_dof_bounds(self) -> None:
|
||||
"""All DOF values should be in [0, 3]."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
for bl in labels.per_body:
|
||||
assert 0 <= bl.translational_dof <= 3
|
||||
assert 0 <= bl.rotational_dof <= 3
|
||||
|
||||
def test_floating_more_dof_than_grounded(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
grounded = label_assembly(bodies, joints, ground_body=0)
|
||||
floating = label_assembly(bodies, joints, ground_body=None)
|
||||
g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body)
|
||||
f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body)
|
||||
assert f_total > g_total
|
||||
|
||||
def test_grounded_body_zero_dof(self) -> None:
|
||||
"""The grounded body should have 0 DOF."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
b0 = next(b for b in labels.per_body if b.body_id == 0)
|
||||
assert b0.translational_dof == 0
|
||||
assert b0.rotational_dof == 0
|
||||
|
||||
def test_body_count_matches(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.BALL),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert len(labels.per_body) == 3
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Assembly label
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAssemblyLabel:
|
||||
"""Assembly-wide summary labels."""
|
||||
|
||||
def test_underconstrained_chain(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.classification == "underconstrained"
|
||||
assert labels.assembly.is_rigid is False
|
||||
|
||||
def test_well_constrained(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.classification == "well-constrained"
|
||||
assert labels.assembly.is_rigid is True
|
||||
assert labels.assembly.is_minimally_rigid is True
|
||||
|
||||
def test_overconstrained(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE),
|
||||
_make_joint(2, 2, 0, JointType.REVOLUTE),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.redundant_count > 0
|
||||
|
||||
def test_has_degeneracy_with_parallel_axes(self) -> None:
|
||||
"""Parallel revolute axes in a loop create geometric degeneracy."""
|
||||
z_axis = (0.0, 0.0, 1.0)
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
|
||||
joints = [
|
||||
_make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis),
|
||||
_make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis),
|
||||
]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
assert labels.assembly.has_degeneracy is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Serialization
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToDict:
|
||||
"""to_dict produces JSON-serializable output."""
|
||||
|
||||
def test_top_level_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
assert set(d.keys()) == {
|
||||
"per_constraint",
|
||||
"per_joint",
|
||||
"per_body",
|
||||
"assembly",
|
||||
}
|
||||
|
||||
def test_per_constraint_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
for item in d["per_constraint"]:
|
||||
assert set(item.keys()) == {
|
||||
"joint_id",
|
||||
"constraint_idx",
|
||||
"pebble_independent",
|
||||
"jacobian_independent",
|
||||
}
|
||||
|
||||
def test_assembly_keys(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
assert set(d["assembly"].keys()) == {
|
||||
"classification",
|
||||
"total_dof",
|
||||
"redundant_count",
|
||||
"is_rigid",
|
||||
"is_minimally_rigid",
|
||||
"has_degeneracy",
|
||||
}
|
||||
|
||||
def test_json_serializable(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
d = labels.to_dict()
|
||||
# Should not raise
|
||||
serialized = json.dumps(d)
|
||||
assert isinstance(serialized, str)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Edge cases
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLabelAssemblyEdgeCases:
|
||||
"""Edge cases for label_assembly."""
|
||||
|
||||
def test_no_joints(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
labels = label_assembly(bodies, [], ground_body=0)
|
||||
assert len(labels.per_constraint) == 0
|
||||
assert len(labels.per_joint) == 0
|
||||
assert labels.assembly.classification == "underconstrained"
|
||||
# Non-ground body should be fully free
|
||||
b1 = next(b for b in labels.per_body if b.body_id == 1)
|
||||
assert b1.translational_dof == 3
|
||||
assert b1.rotational_dof == 3
|
||||
|
||||
def test_no_joints_floating(self) -> None:
|
||||
bodies = _make_bodies((0, 0, 0))
|
||||
labels = label_assembly(bodies, [], ground_body=None)
|
||||
assert len(labels.per_body) == 1
|
||||
assert labels.per_body[0].translational_dof == 3
|
||||
assert labels.per_body[0].rotational_dof == 3
|
||||
|
||||
def test_analysis_embedded(self) -> None:
|
||||
"""AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
|
||||
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
|
||||
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
|
||||
labels = label_assembly(bodies, joints, ground_body=0)
|
||||
analysis = labels.analysis
|
||||
assert hasattr(analysis, "combinatorial_classification")
|
||||
assert hasattr(analysis, "jacobian_rank")
|
||||
assert hasattr(analysis, "is_rigid")
|
||||
Reference in New Issue
Block a user