feat(datagen): add dataset generation CLI with sharding and checkpointing
- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator, ShardSpec/ShardResult dataclasses, parallel shard generation via ProcessPoolExecutor, checkpoint/resume support, index and stats output - Add scripts/generate_synthetic.py CLI entry point with Hydra-first and argparse fallback modes - Add minimal YAML parser (parse_simple_yaml) for config loading without PyYAML dependency - Add progress display with tqdm fallback to print-based ETA - Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every - Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator exports - Add tests/datagen/test_dataset.py with 28 tests covering config, YAML parsing, seed derivation, end-to-end generation, resume, stats/index structure, determinism, and CLI integration Closes #10
This commit is contained in:
337
tests/datagen/test_dataset.py
Normal file
337
tests/datagen/test_dataset.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
_derive_shard_seed,
|
||||
_parse_scalar,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        """A bare config carries the documented default values."""
        config = DatasetConfig()
        assert config.num_assemblies == 100_000
        assert config.seed == 42
        assert config.shard_size == 1000
        assert config.num_workers == 4

    def test_from_dict_flat(self) -> None:
        """Flat scalar keys map straight onto config fields."""
        raw: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
        config = DatasetConfig.from_dict(raw)
        assert (config.num_assemblies, config.seed) == (500, 123)

    def test_from_dict_nested_body_count(self) -> None:
        """A nested body_count mapping feeds the min/max fields."""
        config = DatasetConfig.from_dict({"body_count": {"min": 3, "max": 20}})
        assert (config.body_count_min, config.body_count_max) == (3, 20)

    def test_from_dict_flat_body_count(self) -> None:
        """Flat body_count_min/max keys are accepted as well."""
        raw: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
        config = DatasetConfig.from_dict(raw)
        assert (config.body_count_min, config.body_count_max) == (5, 30)

    def test_from_dict_complexity_distribution(self) -> None:
        """The complexity distribution mapping is preserved verbatim."""
        dist = {"simple": 0.6, "complex": 0.4}
        config = DatasetConfig.from_dict({"complexity_distribution": dist})
        assert config.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        """The templates list is preserved verbatim."""
        config = DatasetConfig.from_dict({"templates": ["chain", "tree"]})
        assert config.templates == ["chain", "tree"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseScalar:
    """_parse_scalar coerces raw YAML scalar text into Python values."""

    def test_int(self) -> None:
        """Digit strings become ints."""
        assert _parse_scalar("42") == 42

    def test_float(self) -> None:
        """Decimal strings become floats."""
        assert _parse_scalar("3.14") == 3.14

    def test_bool_true(self) -> None:
        """The literal 'true' becomes the True singleton."""
        assert _parse_scalar("true") is True

    def test_bool_false(self) -> None:
        """The literal 'false' becomes the False singleton."""
        assert _parse_scalar("false") is False

    def test_string(self) -> None:
        """Anything non-numeric and non-boolean stays a string."""
        assert _parse_scalar("hello") == "hello"

    def test_inline_comment(self) -> None:
        """A trailing '# ...' comment is stripped before parsing."""
        assert _parse_scalar("0.4 # some comment") == 0.4
|
||||
|
||||
|
||||
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        """Top-level scalars get typed values: str, int, float."""
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["name"] == "test"
        assert result["num"] == 42
        assert result["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        """An indented mapping parses into a nested dict."""
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("body_count:\n  min: 2\n  max: 50\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        """A block sequence of '- item' lines parses into a list."""
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("templates:\n  - chain\n  - tree\n  - loop\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        """Inline '# ...' comments are stripped from nested values."""
        yaml_file = tmp_path / "test.yaml"
        yaml_file.write_text("dist:\n  simple: 0.4  # comment\n")
        result = parse_simple_yaml(str(yaml_file))
        assert result["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config.

        The config path is resolved relative to the repository root (two
        levels above tests/datagen/) instead of relying on the process
        cwd, so the test passes no matter where pytest is launched from.
        """
        from pathlib import Path  # runtime import: top-level Path is TYPE_CHECKING-only

        repo_root = Path(__file__).resolve().parents[2]
        config_path = repo_root / "configs" / "dataset" / "synthetic.yaml"
        result = parse_simple_yaml(str(config_path))
        assert result["name"] == "synthetic"
        assert result["num_assemblies"] == 100000
        assert isinstance(result["complexity_distribution"], dict)
        assert isinstance(result["templates"], list)
        assert result["shard_size"] == 1000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        """Identical (global_seed, shard) inputs always yield one value."""
        assert _derive_shard_seed(42, 0) == _derive_shard_seed(42, 0)

    def test_different_shards(self) -> None:
        """Distinct shard indices under one global seed get distinct seeds."""
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        """Distinct global seeds give the same shard distinct seeds."""
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetGenerator — small end-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetGenerator:
    """End-to-end tests with small datasets."""

    @staticmethod
    def _generate(out_dir: Path, *, total: int, shard: int) -> None:
        """Run one generation pass: seed 42, a single worker, given sizing."""
        DatasetGenerator(
            DatasetConfig(
                num_assemblies=total,
                output_dir=str(out_dir),
                shard_size=shard,
                seed=42,
                num_workers=1,
            )
        ).run()

    def test_small_generation(self, tmp_path: Path) -> None:
        """Generate 10 examples in a single shard."""
        out = tmp_path / "output"
        self._generate(out, total=10, shard=10)

        shards_dir = out / "shards"
        assert shards_dir.exists()
        assert len(sorted(shards_dir.glob("shard_*"))) == 1

        index_file = out / "index.json"
        assert index_file.exists()
        assert json.loads(index_file.read_text())["total_assemblies"] == 10

        assert (out / "stats.json").exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """Generate 20 examples across 2 shards."""
        out = tmp_path / "output"
        self._generate(out, total=20, shard=10)
        assert len(sorted((out / "shards").glob("shard_*"))) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """Resume skips already-completed shards."""
        out = tmp_path / "output"
        self._generate(out, total=20, shard=10)

        # Record shard modification times from the first run.
        shards_dir = out / "shards"
        mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}

        # Simulate an interrupted run (stats missing, shards intact) and re-run.
        (out / "stats.json").unlink()
        self._generate(out, total=20, shard=10)

        # Shards should NOT have been regenerated.
        # NOTE(review): relies on the filesystem mtime resolution being fine
        # enough that a rewrite would change it — confirm on the CI filesystem.
        for p in shards_dir.glob("shard_*"):
            assert p.stat().st_mtime == mtimes[p.name]

        # Stats should be regenerated.
        assert (out / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """Checkpoint file is cleaned up after completion."""
        out = tmp_path / "output"
        self._generate(out, total=5, shard=5)
        assert not (out / ".checkpoint.json").exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json has expected top-level keys."""
        out = tmp_path / "output"
        self._generate(out, total=10, shard=10)

        stats = json.loads((out / "stats.json").read_text())
        assert stats["total_examples"] == 10
        for section in (
            "classification_distribution",
            "body_count_histogram",
            "joint_type_distribution",
            "dof_statistics",
            "geometric_degeneracy",
            "rigidity",
        ):
            assert section in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json has expected format."""
        out = tmp_path / "output"
        self._generate(out, total=15, shard=10)

        index = json.loads((out / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        assert index["total_shards"] == 2
        assert "shards" in index
        for info in index["shards"].values():
            assert "start_id" in info
            assert "count" in info

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Same seed produces same results."""
        for run_dir in ("run1", "run2"):
            self._generate(tmp_path / run_dir, total=5, shard=5)

        first = json.loads((tmp_path / "run1" / "stats.json").read_text())
        second = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert first["total_examples"] == second["total_examples"]
        assert (
            first["classification_distribution"]
            == second["classification_distribution"]
        )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI integration test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCLI:
    """Run the script via subprocess."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        """The argparse CLI generates a tiny dataset and writes its outputs.

        The repository root is derived from this file's location
        (tests/datagen/ sits two levels below the root) instead of
        hard-coding an absolute checkout path, so the test works in any
        clone location while preserving the original cwd/PYTHONPATH setup.
        """
        from pathlib import Path  # runtime import: top-level Path is TYPE_CHECKING-only

        repo_root = Path(__file__).resolve().parents[2]
        out_dir = tmp_path / "cli_out"
        result = subprocess.run(
            [
                sys.executable,
                "scripts/generate_synthetic.py",
                "--num-assemblies",
                "5",
                "--output-dir",
                str(out_dir),
                "--shard-size",
                "5",
                "--num-workers",
                "1",
                "--seed",
                "42",
            ],
            capture_output=True,
            text=True,
            cwd=str(repo_root),
            timeout=120,
            # Make the solver package importable in the child process.
            env={**os.environ, "PYTHONPATH": str(repo_root)},
        )
        assert result.returncode == 0, (
            f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )
        assert (out_dir / "index.json").exists()
        assert (out_dir / "stats.json").exists()
|
||||
Reference in New Issue
Block a user