feat: geometric diversity for synthetic assembly generation
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

- Add AxisStrategy type (cardinal, random, near_parallel)
- Add random body orientations via scipy.spatial.transform.Rotation
- Add parallel axis injection with configurable probability
- Add grounded parameter on all 7 generators (grounded/floating)
- Add axis sampling strategies: cardinal, random, near-parallel
- Update _create_joint with orientation-aware anchor offsets
- Add _resolve_axis helper for parallel axis propagation
- Update generate_training_batch with axis_strategy, parallel_axis_prob,
  grounded_ratio parameters
- Add body_orientations and grounded fields to batch output
- Export AxisStrategy from datagen package
- Add 28 new tests (72 total generator tests, 158 total)

Closes #8
This commit is contained in:
2026-02-02 14:57:49 -06:00
parent 0b5813b5a9
commit 78289494e2
3 changed files with 710 additions and 80 deletions

View File

@@ -44,7 +44,10 @@ class TestChainAssembly:
def test_chain_custom_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
_, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL)
_, joints, _ = gen.generate_chain_assembly(
3,
joint_type=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
@@ -77,7 +80,10 @@ class TestOverconstrainedAssembly:
def test_has_redundant(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2)
_, _, analysis = gen.generate_overconstrained_assembly(
4,
extra_joints=2,
)
assert analysis.combinatorial_redundant > 0
def test_extra_joints_added(self) -> None:
@@ -85,7 +91,10 @@ class TestOverconstrainedAssembly:
_, joints_base, _ = gen.generate_rigid_assembly(4)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3)
_, joints_over, _ = gen2.generate_overconstrained_assembly(
4,
extra_joints=3,
)
# Overconstrained has base joints + extra
assert len(joints_over) > len(joints_base)
@@ -111,7 +120,10 @@ class TestTreeAssembly:
def test_branching_factor(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
bodies, joints, _ = gen.generate_tree_assembly(
10,
branching_factor=2,
)
assert len(bodies) == 10
assert len(joints) == 9
@@ -125,7 +137,10 @@ class TestTreeAssembly:
def test_single_joint_type(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
_, joints, _ = gen.generate_tree_assembly(
5,
joint_types=JointType.BALL,
)
assert all(j.joint_type is JointType.BALL for j in joints)
def test_sequential_body_ids(self) -> None:
@@ -206,12 +221,18 @@ class TestMixedAssembly:
def test_more_joints_than_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
bodies, joints, _ = gen.generate_mixed_assembly(
8,
edge_density=0.3,
)
assert len(joints) > len(bodies) - 1
def test_density_zero_is_tree(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=0.0,
)
assert len(joints) == 4 # spanning tree only
def test_density_validation(self) -> None:
@@ -229,11 +250,210 @@ class TestMixedAssembly:
def test_high_density(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
_bodies, joints, _ = gen.generate_mixed_assembly(
5,
edge_density=1.0,
)
# Fully connected: 5*(5-1)/2 = 10 edges
assert len(joints) == 10
# ---------------------------------------------------------------------------
# Axis sampling strategies
# ---------------------------------------------------------------------------
class TestAxisStrategy:
"""Axis sampling strategies produce valid unit vectors."""
def test_cardinal_axis_from_six(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
expected = {
(1, 0, 0),
(-1, 0, 0),
(0, 1, 0),
(0, -1, 0),
(0, 0, 1),
(0, 0, -1),
}
assert axes == expected
def test_random_axis_unit_norm(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
for _ in range(50):
axis = gen._sample_axis("random")
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
def test_near_parallel_close_to_base(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
base = np.array([0.0, 0.0, 1.0])
for _ in range(50):
axis = gen._near_parallel_axis(base)
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
assert np.dot(axis, base) > 0.95
def test_sample_axis_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("cardinal")
cardinals = [
np.array(v, dtype=float)
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
]
assert any(np.allclose(axis, c) for c in cardinals)
def test_sample_axis_near_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=0)
axis = gen._sample_axis("near_parallel")
z = np.array([0.0, 0.0, 1.0])
assert np.dot(axis, z) > 0.95
# ---------------------------------------------------------------------------
# Geometric diversity: orientations
# ---------------------------------------------------------------------------
class TestRandomOrientations:
"""Bodies should have non-identity orientations."""
def test_bodies_have_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_tree_assembly(5)
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
assert non_identity >= 3
def test_orientations_are_valid_rotations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
bodies, _, _ = gen.generate_star_assembly(6)
for b in bodies:
r = b.orientation
# R^T R == I
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
# det(R) == 1
assert abs(np.linalg.det(r) - 1.0) < 1e-10
def test_all_generators_set_orientations(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
# Chain
bodies, _, _ = gen.generate_chain_assembly(3)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Loop
bodies, _, _ = gen.generate_loop_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# Mixed
bodies, _, _ = gen.generate_mixed_assembly(4)
assert not np.allclose(bodies[1].orientation, np.eye(3))
# ---------------------------------------------------------------------------
# Geometric diversity: grounded parameter
# ---------------------------------------------------------------------------
class TestGroundedParameter:
"""Grounded parameter controls ground_body in analysis."""
def test_chain_grounded_default(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(4)
assert analysis.combinatorial_dof >= 0
def test_chain_floating(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, _, analysis = gen.generate_chain_assembly(
4,
grounded=False,
)
# Floating: 6 trivial DOF not subtracted by ground
assert analysis.combinatorial_dof >= 6
def test_floating_vs_grounded_dof_difference(self) -> None:
gen1 = SyntheticAssemblyGenerator(seed=42)
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
gen2 = SyntheticAssemblyGenerator(seed=42)
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
# Floating should have higher DOF due to missing ground constraint
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
@pytest.mark.parametrize(
"gen_method",
[
"generate_chain_assembly",
"generate_rigid_assembly",
"generate_tree_assembly",
"generate_loop_assembly",
"generate_star_assembly",
"generate_mixed_assembly",
],
)
def test_all_generators_accept_grounded(
self,
gen_method: str,
) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
method = getattr(gen, gen_method)
n = 4
# Should not raise
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
method(n, grounded=False)
else:
method(n, grounded=False)
# ---------------------------------------------------------------------------
# Geometric diversity: parallel axis injection
# ---------------------------------------------------------------------------
class TestParallelAxisInjection:
"""parallel_axis_prob causes shared axis direction."""
def test_parallel_axes_similar(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
# Near-parallel: |dot| close to 1
assert abs(np.dot(j.axis, base)) > 0.9
def test_zero_prob_no_forced_parallel(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_chain_assembly(
6,
parallel_axis_prob=0.0,
)
base = joints[0].axis
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
# With 5 random axes, extremely unlikely all are parallel
assert min(dots) < 0.95
def test_parallel_on_loop(self) -> None:
"""Parallel axes on a loop assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_loop_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
def test_parallel_on_star(self) -> None:
"""Parallel axes on a star assembly."""
gen = SyntheticAssemblyGenerator(seed=42)
_, joints, _ = gen.generate_star_assembly(
5,
parallel_axis_prob=1.0,
)
base = joints[0].axis
for j in joints[1:]:
assert abs(np.dot(j.axis, base)) > 0.9
# ---------------------------------------------------------------------------
# Complexity tiers
# ---------------------------------------------------------------------------
@@ -265,7 +485,11 @@ class TestComplexityTiers:
def test_tier_overrides_range(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
batch = gen.generate_training_batch(
10,
n_bodies_range=(2, 3),
complexity_tier="medium",
)
lo, hi = COMPLEXITY_RANGES["medium"]
for ex in batch:
assert lo <= ex["n_bodies"] < hi
@@ -282,9 +506,11 @@ class TestTrainingBatch:
EXPECTED_KEYS: ClassVar[set[str]] = {
"example_id",
"generator_type",
"grounded",
"n_bodies",
"n_joints",
"body_positions",
"body_orientations",
"joints",
"joint_labels",
"assembly_classification",
@@ -337,7 +563,8 @@ class TestTrainingBatch:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(30)
for ex in batch:
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
# default (3, 8), but loop/star may clamp
assert 2 <= ex["n_bodies"] <= 7
def test_joint_label_consistency(self) -> None:
"""independent + redundant == total for every joint."""
@@ -348,6 +575,70 @@ class TestTrainingBatch:
total = label["independent_constraints"] + label["redundant_constraints"]
assert total == label["total_constraints"]
def test_body_orientations_present(self) -> None:
"""Each example includes body_orientations as 3x3 lists."""
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
orients = ex["body_orientations"]
assert len(orients) == ex["n_bodies"]
for o in orients:
assert len(o) == 3
assert len(o[0]) == 3
def test_grounded_field_present(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(10)
for ex in batch:
assert isinstance(ex["grounded"], bool)
# ---------------------------------------------------------------------------
# Batch grounded ratio
# ---------------------------------------------------------------------------
class TestBatchGroundedRatio:
"""grounded_ratio controls the mix in batch generation."""
def test_all_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
assert all(ex["grounded"] for ex in batch)
def test_none_grounded(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
assert not any(ex["grounded"] for ex in batch)
def test_mixed_ratio(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
grounded_count = sum(1 for ex in batch if ex["grounded"])
# With 100 samples and p=0.5, should be roughly 50 +/- 20
assert 20 < grounded_count < 80
def test_batch_axis_strategy_cardinal(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
axis_strategy="cardinal",
)
assert len(batch) == 10
def test_batch_parallel_axis_prob(self) -> None:
gen = SyntheticAssemblyGenerator(seed=42)
batch = gen.generate_training_batch(
10,
parallel_axis_prob=0.5,
)
assert len(batch) == 10
# ---------------------------------------------------------------------------
# Seed reproducibility
# ---------------------------------------------------------------------------
class TestSeedReproducibility:
"""Different seeds produce different results."""
@@ -355,8 +646,14 @@ class TestSeedReproducibility:
def test_different_seeds_differ(self) -> None:
g1 = SyntheticAssemblyGenerator(seed=1)
g2 = SyntheticAssemblyGenerator(seed=2)
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
b1 = g1.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
b2 = g2.generate_training_batch(
batch_size=5,
n_bodies_range=(3, 6),
)
c1 = [ex["assembly_classification"] for ex in b1]
c2 = [ex["assembly_classification"] for ex in b2]
r1 = [ex["is_rigid"] for ex in b1]