feat: ground truth labeling pipeline
Some checks failed
CI / lint (push) Failing after 25m6s
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

- Create solver/datagen/labeling.py with label_assembly() function
- Add dataclasses: ConstraintLabel, JointLabel, BodyDofLabel,
  AssemblyLabel, AssemblyLabels
- Per-constraint labels: pebble_independent + jacobian_independent
- Per-joint labels: aggregated independent/redundant/total counts
- Per-body DOF: translational + rotational from nullspace projection
- Assembly label: classification, total_dof, has_degeneracy flag
- AssemblyLabels.to_dict() for JSON-serializable output
- Integrate into generate_training_batch (adds 'labels' field)
- Export AssemblyLabels and label_assembly from datagen package
- Add 25 labeling tests + 1 batch structure test (184 total)

Closes #9
This commit is contained in:
2026-02-02 15:20:02 -06:00
parent 78289494e2
commit 8a49f8ef40
5 changed files with 762 additions and 0 deletions

View File

@@ -513,6 +513,7 @@ class TestTrainingBatch:
"body_orientations",
"joints",
"joint_labels",
"labels",
"assembly_classification",
"is_rigid",
"is_minimally_rigid",
@@ -586,6 +587,17 @@ class TestTrainingBatch:
assert len(o) == 3
assert len(o[0]) == 3
def test_labels_structure(self) -> None:
"""Each example has labels dict with expected sub-keys."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
labels = ex["labels"]
assert "per_constraint" in labels
assert "per_joint" in labels
assert "per_body" in labels
assert "assembly" in labels
def test_grounded_field_present(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)

View File

@@ -0,0 +1,346 @@
"""Tests for solver.datagen.labeling -- ground truth labeling pipeline."""
from __future__ import annotations
import json
import numpy as np
from solver.datagen.labeling import (
label_assembly,
)
from solver.datagen.types import Joint, JointType, RigidBody
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_bodies(*positions: tuple[float, ...]) -> list[RigidBody]:
return [RigidBody(body_id=i, position=np.array(pos)) for i, pos in enumerate(positions)]
def _make_joint(
jid: int,
a: int,
b: int,
jtype: JointType,
axis: tuple[float, ...] = (0.0, 0.0, 1.0),
) -> Joint:
return Joint(
joint_id=jid,
body_a=a,
body_b=b,
joint_type=jtype,
anchor_a=np.zeros(3),
anchor_b=np.zeros(3),
axis=np.array(axis),
)
# ---------------------------------------------------------------------------
# Per-constraint labels
# ---------------------------------------------------------------------------
class TestConstraintLabels:
"""Per-constraint labels combine pebble game and Jacobian results."""
def test_fixed_joint_all_independent(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 6
for cl in labels.per_constraint:
assert cl.pebble_independent is True
assert cl.jacobian_independent is True
def test_revolute_joint_all_independent(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 5
for cl in labels.per_constraint:
assert cl.pebble_independent is True
assert cl.jacobian_independent is True
def test_chain_constraint_count(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_constraint) == 10 # 5 + 5
def test_constraint_joint_ids(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.BALL),
]
labels = label_assembly(bodies, joints, ground_body=0)
j0_constraints = [c for c in labels.per_constraint if c.joint_id == 0]
j1_constraints = [c for c in labels.per_constraint if c.joint_id == 1]
assert len(j0_constraints) == 5 # revolute
assert len(j1_constraints) == 3 # ball
def test_overconstrained_has_pebble_redundant(self) -> None:
"""Triangle with revolute joints: some constraints redundant."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
pebble_redundant = sum(1 for c in labels.per_constraint if not c.pebble_independent)
assert pebble_redundant > 0
# ---------------------------------------------------------------------------
# Per-joint labels
# ---------------------------------------------------------------------------
class TestJointLabels:
"""Per-joint aggregated labels."""
def test_fixed_joint_counts(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_joint) == 1
jl = labels.per_joint[0]
assert jl.joint_id == 0
assert jl.independent_count == 6
assert jl.redundant_count == 0
assert jl.total == 6
def test_overconstrained_has_redundant_joints(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
total_redundant = sum(jl.redundant_count for jl in labels.per_joint)
assert total_redundant > 0
def test_joint_total_equals_dof(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.BALL)]
labels = label_assembly(bodies, joints, ground_body=0)
jl = labels.per_joint[0]
assert jl.total == 3 # ball has 3 DOF
# ---------------------------------------------------------------------------
# Per-body DOF labels
# ---------------------------------------------------------------------------
class TestBodyDofLabels:
"""Per-body DOF signatures from nullspace projection."""
def test_fixed_joint_grounded_both_zero(self) -> None:
"""Two bodies + fixed joint + grounded: both fully constrained."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
for bl in labels.per_body:
assert bl.translational_dof == 0
assert bl.rotational_dof == 0
def test_revolute_has_rotational_dof(self) -> None:
"""Two bodies + revolute + grounded: body 1 has rotational DOF."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
b1 = next(b for b in labels.per_body if b.body_id == 1)
# Revolute allows 1 rotation DOF
assert b1.rotational_dof >= 1
def test_dof_bounds(self) -> None:
"""All DOF values should be in [0, 3]."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
for bl in labels.per_body:
assert 0 <= bl.translational_dof <= 3
assert 0 <= bl.rotational_dof <= 3
def test_floating_more_dof_than_grounded(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
grounded = label_assembly(bodies, joints, ground_body=0)
floating = label_assembly(bodies, joints, ground_body=None)
g_total = sum(b.translational_dof + b.rotational_dof for b in grounded.per_body)
f_total = sum(b.translational_dof + b.rotational_dof for b in floating.per_body)
assert f_total > g_total
def test_grounded_body_zero_dof(self) -> None:
"""The grounded body should have 0 DOF."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
b0 = next(b for b in labels.per_body if b.body_id == 0)
assert b0.translational_dof == 0
assert b0.rotational_dof == 0
def test_body_count_matches(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.BALL),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert len(labels.per_body) == 3
# ---------------------------------------------------------------------------
# Assembly label
# ---------------------------------------------------------------------------
class TestAssemblyLabel:
"""Assembly-wide summary labels."""
def test_underconstrained_chain(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (4, 0, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.classification == "underconstrained"
assert labels.assembly.is_rigid is False
def test_well_constrained(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.FIXED)]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.classification == "well-constrained"
assert labels.assembly.is_rigid is True
assert labels.assembly.is_minimally_rigid is True
def test_overconstrained(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (1, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE),
_make_joint(1, 1, 2, JointType.REVOLUTE),
_make_joint(2, 2, 0, JointType.REVOLUTE),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.redundant_count > 0
def test_has_degeneracy_with_parallel_axes(self) -> None:
"""Parallel revolute axes in a loop create geometric degeneracy."""
z_axis = (0.0, 0.0, 1.0)
bodies = _make_bodies((0, 0, 0), (2, 0, 0), (2, 2, 0), (0, 2, 0))
joints = [
_make_joint(0, 0, 1, JointType.REVOLUTE, axis=z_axis),
_make_joint(1, 1, 2, JointType.REVOLUTE, axis=z_axis),
_make_joint(2, 2, 3, JointType.REVOLUTE, axis=z_axis),
_make_joint(3, 3, 0, JointType.REVOLUTE, axis=z_axis),
]
labels = label_assembly(bodies, joints, ground_body=0)
assert labels.assembly.has_degeneracy is True
# ---------------------------------------------------------------------------
# Serialization
# ---------------------------------------------------------------------------
class TestToDict:
"""to_dict produces JSON-serializable output."""
def test_top_level_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
assert set(d.keys()) == {
"per_constraint",
"per_joint",
"per_body",
"assembly",
}
def test_per_constraint_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
for item in d["per_constraint"]:
assert set(item.keys()) == {
"joint_id",
"constraint_idx",
"pebble_independent",
"jacobian_independent",
}
def test_assembly_keys(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
assert set(d["assembly"].keys()) == {
"classification",
"total_dof",
"redundant_count",
"is_rigid",
"is_minimally_rigid",
"has_degeneracy",
}
def test_json_serializable(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
d = labels.to_dict()
# Should not raise
serialized = json.dumps(d)
assert isinstance(serialized, str)
# ---------------------------------------------------------------------------
# Edge cases
# ---------------------------------------------------------------------------
class TestLabelAssemblyEdgeCases:
"""Edge cases for label_assembly."""
def test_no_joints(self) -> None:
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
labels = label_assembly(bodies, [], ground_body=0)
assert len(labels.per_constraint) == 0
assert len(labels.per_joint) == 0
assert labels.assembly.classification == "underconstrained"
# Non-ground body should be fully free
b1 = next(b for b in labels.per_body if b.body_id == 1)
assert b1.translational_dof == 3
assert b1.rotational_dof == 3
def test_no_joints_floating(self) -> None:
bodies = _make_bodies((0, 0, 0))
labels = label_assembly(bodies, [], ground_body=None)
assert len(labels.per_body) == 1
assert labels.per_body[0].translational_dof == 3
assert labels.per_body[0].rotational_dof == 3
def test_analysis_embedded(self) -> None:
"""AssemblyLabels.analysis should be a valid ConstraintAnalysis."""
bodies = _make_bodies((0, 0, 0), (2, 0, 0))
joints = [_make_joint(0, 0, 1, JointType.REVOLUTE)]
labels = label_assembly(bodies, joints, ground_body=0)
analysis = labels.analysis
assert hasattr(analysis, "combinatorial_classification")
assert hasattr(analysis, "jacobian_rank")
assert hasattr(analysis, "is_rigid")