feat: geometric diversity for synthetic assembly generation
- Add AxisStrategy type (cardinal, random, near_parallel) - Add random body orientations via scipy.spatial.transform.Rotation - Add parallel axis injection with configurable probability - Add grounded parameter on all 7 generators (grounded/floating) - Add axis sampling strategies: cardinal, random, near-parallel - Update _create_joint with orientation-aware anchor offsets - Add _resolve_axis helper for parallel axis propagation - Update generate_training_batch with axis_strategy, parallel_axis_prob, grounded_ratio parameters - Add body_orientations and grounded fields to batch output - Export AxisStrategy from datagen package - Add 28 new tests (72 total generator tests, 158 total) Closes #8
This commit is contained in:
@@ -44,7 +44,10 @@ class TestChainAssembly:
|
||||
|
||||
def test_chain_custom_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
_, joints, _ = gen.generate_chain_assembly(3, joint_type=JointType.BALL)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
3,
|
||||
joint_type=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
|
||||
@@ -77,7 +80,10 @@ class TestOverconstrainedAssembly:
|
||||
|
||||
def test_has_redundant(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_overconstrained_assembly(4, extra_joints=2)
|
||||
_, _, analysis = gen.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=2,
|
||||
)
|
||||
assert analysis.combinatorial_redundant > 0
|
||||
|
||||
def test_extra_joints_added(self) -> None:
|
||||
@@ -85,7 +91,10 @@ class TestOverconstrainedAssembly:
|
||||
_, joints_base, _ = gen.generate_rigid_assembly(4)
|
||||
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints_over, _ = gen2.generate_overconstrained_assembly(4, extra_joints=3)
|
||||
_, joints_over, _ = gen2.generate_overconstrained_assembly(
|
||||
4,
|
||||
extra_joints=3,
|
||||
)
|
||||
# Overconstrained has base joints + extra
|
||||
assert len(joints_over) > len(joints_base)
|
||||
|
||||
@@ -111,7 +120,10 @@ class TestTreeAssembly:
|
||||
|
||||
def test_branching_factor(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(10, branching_factor=2)
|
||||
bodies, joints, _ = gen.generate_tree_assembly(
|
||||
10,
|
||||
branching_factor=2,
|
||||
)
|
||||
assert len(bodies) == 10
|
||||
assert len(joints) == 9
|
||||
|
||||
@@ -125,7 +137,10 @@ class TestTreeAssembly:
|
||||
|
||||
def test_single_joint_type(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_tree_assembly(5, joint_types=JointType.BALL)
|
||||
_, joints, _ = gen.generate_tree_assembly(
|
||||
5,
|
||||
joint_types=JointType.BALL,
|
||||
)
|
||||
assert all(j.joint_type is JointType.BALL for j in joints)
|
||||
|
||||
def test_sequential_body_ids(self) -> None:
|
||||
@@ -206,12 +221,18 @@ class TestMixedAssembly:
|
||||
|
||||
def test_more_joints_than_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(8, edge_density=0.3)
|
||||
bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
8,
|
||||
edge_density=0.3,
|
||||
)
|
||||
assert len(joints) > len(bodies) - 1
|
||||
|
||||
def test_density_zero_is_tree(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=0.0)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=0.0,
|
||||
)
|
||||
assert len(joints) == 4 # spanning tree only
|
||||
|
||||
def test_density_validation(self) -> None:
|
||||
@@ -229,11 +250,210 @@ class TestMixedAssembly:
|
||||
|
||||
def test_high_density(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(5, edge_density=1.0)
|
||||
_bodies, joints, _ = gen.generate_mixed_assembly(
|
||||
5,
|
||||
edge_density=1.0,
|
||||
)
|
||||
# Fully connected: 5*(5-1)/2 = 10 edges
|
||||
assert len(joints) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Axis sampling strategies
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestAxisStrategy:
|
||||
"""Axis sampling strategies produce valid unit vectors."""
|
||||
|
||||
def test_cardinal_axis_from_six(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axes = {tuple(gen._cardinal_axis()) for _ in range(200)}
|
||||
expected = {
|
||||
(1, 0, 0),
|
||||
(-1, 0, 0),
|
||||
(0, 1, 0),
|
||||
(0, -1, 0),
|
||||
(0, 0, 1),
|
||||
(0, 0, -1),
|
||||
}
|
||||
assert axes == expected
|
||||
|
||||
def test_random_axis_unit_norm(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
for _ in range(50):
|
||||
axis = gen._sample_axis("random")
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
|
||||
def test_near_parallel_close_to_base(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
base = np.array([0.0, 0.0, 1.0])
|
||||
for _ in range(50):
|
||||
axis = gen._near_parallel_axis(base)
|
||||
assert abs(np.linalg.norm(axis) - 1.0) < 1e-10
|
||||
assert np.dot(axis, base) > 0.95
|
||||
|
||||
def test_sample_axis_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("cardinal")
|
||||
cardinals = [
|
||||
np.array(v, dtype=float)
|
||||
for v in [(1, 0, 0), (-1, 0, 0), (0, 1, 0), (0, -1, 0), (0, 0, 1), (0, 0, -1)]
|
||||
]
|
||||
assert any(np.allclose(axis, c) for c in cardinals)
|
||||
|
||||
def test_sample_axis_near_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=0)
|
||||
axis = gen._sample_axis("near_parallel")
|
||||
z = np.array([0.0, 0.0, 1.0])
|
||||
assert np.dot(axis, z) > 0.95
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: orientations
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRandomOrientations:
|
||||
"""Bodies should have non-identity orientations."""
|
||||
|
||||
def test_bodies_have_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_tree_assembly(5)
|
||||
non_identity = sum(1 for b in bodies if not np.allclose(b.orientation, np.eye(3)))
|
||||
assert non_identity >= 3
|
||||
|
||||
def test_orientations_are_valid_rotations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
bodies, _, _ = gen.generate_star_assembly(6)
|
||||
for b in bodies:
|
||||
r = b.orientation
|
||||
# R^T R == I
|
||||
np.testing.assert_allclose(r.T @ r, np.eye(3), atol=1e-10)
|
||||
# det(R) == 1
|
||||
assert abs(np.linalg.det(r) - 1.0) < 1e-10
|
||||
|
||||
def test_all_generators_set_orientations(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
# Chain
|
||||
bodies, _, _ = gen.generate_chain_assembly(3)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Loop
|
||||
bodies, _, _ = gen.generate_loop_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
# Mixed
|
||||
bodies, _, _ = gen.generate_mixed_assembly(4)
|
||||
assert not np.allclose(bodies[1].orientation, np.eye(3))
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: grounded parameter
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGroundedParameter:
|
||||
"""Grounded parameter controls ground_body in analysis."""
|
||||
|
||||
def test_chain_grounded_default(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(4)
|
||||
assert analysis.combinatorial_dof >= 0
|
||||
|
||||
def test_chain_floating(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, analysis = gen.generate_chain_assembly(
|
||||
4,
|
||||
grounded=False,
|
||||
)
|
||||
# Floating: 6 trivial DOF not subtracted by ground
|
||||
assert analysis.combinatorial_dof >= 6
|
||||
|
||||
def test_floating_vs_grounded_dof_difference(self) -> None:
|
||||
gen1 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_grounded = gen1.generate_chain_assembly(4, grounded=True)
|
||||
gen2 = SyntheticAssemblyGenerator(seed=42)
|
||||
_, _, a_floating = gen2.generate_chain_assembly(4, grounded=False)
|
||||
# Floating should have higher DOF due to missing ground constraint
|
||||
assert a_floating.combinatorial_dof > a_grounded.combinatorial_dof
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"gen_method",
|
||||
[
|
||||
"generate_chain_assembly",
|
||||
"generate_rigid_assembly",
|
||||
"generate_tree_assembly",
|
||||
"generate_loop_assembly",
|
||||
"generate_star_assembly",
|
||||
"generate_mixed_assembly",
|
||||
],
|
||||
)
|
||||
def test_all_generators_accept_grounded(
|
||||
self,
|
||||
gen_method: str,
|
||||
) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
method = getattr(gen, gen_method)
|
||||
n = 4
|
||||
# Should not raise
|
||||
if gen_method in ("generate_chain_assembly", "generate_rigid_assembly"):
|
||||
method(n, grounded=False)
|
||||
else:
|
||||
method(n, grounded=False)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Geometric diversity: parallel axis injection
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParallelAxisInjection:
|
||||
"""parallel_axis_prob causes shared axis direction."""
|
||||
|
||||
def test_parallel_axes_similar(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
# Near-parallel: |dot| close to 1
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_zero_prob_no_forced_parallel(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_chain_assembly(
|
||||
6,
|
||||
parallel_axis_prob=0.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
dots = [abs(np.dot(j.axis, base)) for j in joints[1:]]
|
||||
# With 5 random axes, extremely unlikely all are parallel
|
||||
assert min(dots) < 0.95
|
||||
|
||||
def test_parallel_on_loop(self) -> None:
|
||||
"""Parallel axes on a loop assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_loop_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
def test_parallel_on_star(self) -> None:
|
||||
"""Parallel axes on a star assembly."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
_, joints, _ = gen.generate_star_assembly(
|
||||
5,
|
||||
parallel_axis_prob=1.0,
|
||||
)
|
||||
base = joints[0].axis
|
||||
for j in joints[1:]:
|
||||
assert abs(np.dot(j.axis, base)) > 0.9
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Complexity tiers
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -265,7 +485,11 @@ class TestComplexityTiers:
|
||||
|
||||
def test_tier_overrides_range(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10, n_bodies_range=(2, 3), complexity_tier="medium")
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
n_bodies_range=(2, 3),
|
||||
complexity_tier="medium",
|
||||
)
|
||||
lo, hi = COMPLEXITY_RANGES["medium"]
|
||||
for ex in batch:
|
||||
assert lo <= ex["n_bodies"] < hi
|
||||
@@ -282,9 +506,11 @@ class TestTrainingBatch:
|
||||
EXPECTED_KEYS: ClassVar[set[str]] = {
|
||||
"example_id",
|
||||
"generator_type",
|
||||
"grounded",
|
||||
"n_bodies",
|
||||
"n_joints",
|
||||
"body_positions",
|
||||
"body_orientations",
|
||||
"joints",
|
||||
"joint_labels",
|
||||
"assembly_classification",
|
||||
@@ -337,7 +563,8 @@ class TestTrainingBatch:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(30)
|
||||
for ex in batch:
|
||||
assert 2 <= ex["n_bodies"] <= 7 # default (3, 8), but loop/star may clamp
|
||||
# default (3, 8), but loop/star may clamp
|
||||
assert 2 <= ex["n_bodies"] <= 7
|
||||
|
||||
def test_joint_label_consistency(self) -> None:
|
||||
"""independent + redundant == total for every joint."""
|
||||
@@ -348,6 +575,70 @@ class TestTrainingBatch:
|
||||
total = label["independent_constraints"] + label["redundant_constraints"]
|
||||
assert total == label["total_constraints"]
|
||||
|
||||
def test_body_orientations_present(self) -> None:
|
||||
"""Each example includes body_orientations as 3x3 lists."""
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
orients = ex["body_orientations"]
|
||||
assert len(orients) == ex["n_bodies"]
|
||||
for o in orients:
|
||||
assert len(o) == 3
|
||||
assert len(o[0]) == 3
|
||||
|
||||
def test_grounded_field_present(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(10)
|
||||
for ex in batch:
|
||||
assert isinstance(ex["grounded"], bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Batch grounded ratio
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestBatchGroundedRatio:
|
||||
"""grounded_ratio controls the mix in batch generation."""
|
||||
|
||||
def test_all_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=1.0)
|
||||
assert all(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_none_grounded(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(20, grounded_ratio=0.0)
|
||||
assert not any(ex["grounded"] for ex in batch)
|
||||
|
||||
def test_mixed_ratio(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(100, grounded_ratio=0.5)
|
||||
grounded_count = sum(1 for ex in batch if ex["grounded"])
|
||||
# With 100 samples and p=0.5, should be roughly 50 +/- 20
|
||||
assert 20 < grounded_count < 80
|
||||
|
||||
def test_batch_axis_strategy_cardinal(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
axis_strategy="cardinal",
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
def test_batch_parallel_axis_prob(self) -> None:
|
||||
gen = SyntheticAssemblyGenerator(seed=42)
|
||||
batch = gen.generate_training_batch(
|
||||
10,
|
||||
parallel_axis_prob=0.5,
|
||||
)
|
||||
assert len(batch) == 10
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed reproducibility
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSeedReproducibility:
|
||||
"""Different seeds produce different results."""
|
||||
@@ -355,8 +646,14 @@ class TestSeedReproducibility:
|
||||
def test_different_seeds_differ(self) -> None:
|
||||
g1 = SyntheticAssemblyGenerator(seed=1)
|
||||
g2 = SyntheticAssemblyGenerator(seed=2)
|
||||
b1 = g1.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b2 = g2.generate_training_batch(batch_size=5, n_bodies_range=(3, 6))
|
||||
b1 = g1.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
b2 = g2.generate_training_batch(
|
||||
batch_size=5,
|
||||
n_bodies_range=(3, 6),
|
||||
)
|
||||
c1 = [ex["assembly_classification"] for ex in b1]
|
||||
c2 = [ex["assembly_classification"] for ex in b2]
|
||||
r1 = [ex["is_rigid"] for ex in b1]
|
||||
|
||||
Reference in New Issue
Block a user