feat(datagen): add dataset generation CLI with sharding and checkpointing
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled

- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator,
  ShardSpec/ShardResult dataclasses, parallel shard generation via
  ProcessPoolExecutor, checkpoint/resume support, index and stats output
- Add scripts/generate_synthetic.py CLI entry point with Hydra-first
  and argparse fallback modes
- Add minimal YAML parser (parse_simple_yaml) for config loading
  without PyYAML dependency
- Add progress display with tqdm fallback to print-based ETA
- Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every
- Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator
  exports
- Add tests/datagen/test_dataset.py with 28 tests covering config,
  YAML parsing, seed derivation, end-to-end generation, resume,
  stats/index structure, determinism, and CLI integration

Closes #10
This commit is contained in:
2026-02-03 08:44:31 -06:00
parent 8a49f8ef40
commit f29060491e
5 changed files with 1081 additions and 0 deletions

View File

@@ -0,0 +1,337 @@
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
from __future__ import annotations
import json
import os
import subprocess
import sys
from typing import TYPE_CHECKING
from solver.datagen.dataset import (
DatasetConfig,
DatasetGenerator,
_derive_shard_seed,
_parse_scalar,
parse_simple_yaml,
)
if TYPE_CHECKING:
from pathlib import Path
from typing import Any
# ---------------------------------------------------------------------------
# DatasetConfig
# ---------------------------------------------------------------------------
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        """A bare config exposes the documented default values."""
        config = DatasetConfig()
        assert config.num_assemblies == 100_000
        assert config.seed == 42
        assert config.shard_size == 1000
        assert config.num_workers == 4

    def test_from_dict_flat(self) -> None:
        """Flat keys map straight onto config attributes."""
        raw: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
        config = DatasetConfig.from_dict(raw)
        assert config.num_assemblies == 500
        assert config.seed == 123

    def test_from_dict_nested_body_count(self) -> None:
        """A nested body_count mapping expands to min/max attributes."""
        raw: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
        config = DatasetConfig.from_dict(raw)
        assert config.body_count_min == 3
        assert config.body_count_max == 20

    def test_from_dict_flat_body_count(self) -> None:
        """Flat body_count_min/body_count_max keys are also accepted."""
        raw: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
        config = DatasetConfig.from_dict(raw)
        assert config.body_count_min == 5
        assert config.body_count_max == 30

    def test_from_dict_complexity_distribution(self) -> None:
        """The complexity distribution mapping is carried over verbatim."""
        raw: dict[str, Any] = {
            "complexity_distribution": {"simple": 0.6, "complex": 0.4}
        }
        config = DatasetConfig.from_dict(raw)
        assert config.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        """The templates list is carried over verbatim."""
        raw: dict[str, Any] = {"templates": ["chain", "tree"]}
        config = DatasetConfig.from_dict(raw)
        assert config.templates == ["chain", "tree"]
# ---------------------------------------------------------------------------
# Minimal YAML parser
# ---------------------------------------------------------------------------
class TestParseScalar:
    """_parse_scalar handles different value types."""

    def test_int(self) -> None:
        """Digit strings parse to int."""
        result = _parse_scalar("42")
        assert result == 42

    def test_float(self) -> None:
        """Decimal strings parse to float."""
        result = _parse_scalar("3.14")
        assert result == 3.14

    def test_bool_true(self) -> None:
        """The literal 'true' parses to the True singleton."""
        assert _parse_scalar("true") is True

    def test_bool_false(self) -> None:
        """The literal 'false' parses to the False singleton."""
        assert _parse_scalar("false") is False

    def test_string(self) -> None:
        """Anything else is returned as a plain string."""
        result = _parse_scalar("hello")
        assert result == "hello"

    def test_inline_comment(self) -> None:
        """Trailing '# ...' comments are stripped before parsing."""
        result = _parse_scalar("0.4 # some comment")
        assert result == 0.4
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        """Top-level scalar keys parse to their typed values."""
        path = tmp_path / "test.yaml"
        path.write_text("name: test\nnum: 42\nratio: 0.5\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["name"] == "test"
        assert parsed["num"] == 42
        assert parsed["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        """An indented block becomes a nested dict."""
        path = tmp_path / "test.yaml"
        path.write_text("body_count:\n min: 2\n max: 50\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        """Dash-prefixed items become a list."""
        path = tmp_path / "test.yaml"
        path.write_text("templates:\n - chain\n - tree\n - loop\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        """Inline comments are stripped from nested values."""
        path = tmp_path / "test.yaml"
        path.write_text("dist:\n simple: 0.4 # comment\n")
        parsed = parse_simple_yaml(str(path))
        assert parsed["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config."""
        parsed = parse_simple_yaml("configs/dataset/synthetic.yaml")
        assert parsed["name"] == "synthetic"
        assert parsed["num_assemblies"] == 100000
        assert isinstance(parsed["complexity_distribution"], dict)
        assert isinstance(parsed["templates"], list)
        assert parsed["shard_size"] == 1000
# ---------------------------------------------------------------------------
# Shard seed derivation
# ---------------------------------------------------------------------------
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        """Identical (seed, shard) inputs always yield the same value."""
        first = _derive_shard_seed(42, 0)
        second = _derive_shard_seed(42, 0)
        assert first == second

    def test_different_shards(self) -> None:
        """Distinct shard indices under one global seed must not collide."""
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        """The same shard index under distinct global seeds must not collide."""
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
# ---------------------------------------------------------------------------
# DatasetGenerator — small end-to-end tests
# ---------------------------------------------------------------------------
class TestDatasetGenerator:
    """End-to-end tests with small datasets."""

    @staticmethod
    def _cfg(out_dir: Path, *, total: int, per_shard: int) -> DatasetConfig:
        """Build a deterministic single-worker config writing to *out_dir*."""
        return DatasetConfig(
            num_assemblies=total,
            output_dir=str(out_dir),
            shard_size=per_shard,
            seed=42,
            num_workers=1,
        )

    def test_small_generation(self, tmp_path: Path) -> None:
        """Generate 10 examples in a single shard."""
        out = tmp_path / "output"
        DatasetGenerator(self._cfg(out, total=10, per_shard=10)).run()

        shards_dir = out / "shards"
        assert shards_dir.exists()
        assert len(sorted(shards_dir.glob("shard_*"))) == 1

        index_file = out / "index.json"
        assert index_file.exists()
        assert json.loads(index_file.read_text())["total_assemblies"] == 10
        assert (out / "stats.json").exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """Generate 20 examples across 2 shards."""
        out = tmp_path / "output"
        DatasetGenerator(self._cfg(out, total=20, per_shard=10)).run()
        assert len(sorted((out / "shards").glob("shard_*"))) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """Resume skips already-completed shards."""
        out = tmp_path / "output"
        cfg = self._cfg(out, total=20, per_shard=10)
        DatasetGenerator(cfg).run()

        # Record shard modification times before the second run.
        shards_dir = out / "shards"
        before = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}

        # Remove stats (simulate incomplete) and re-run.
        (out / "stats.json").unlink()
        DatasetGenerator(cfg).run()

        # Shards should NOT have been regenerated ...
        for shard in shards_dir.glob("shard_*"):
            assert shard.stat().st_mtime == before[shard.name]
        # ... but stats should be regenerated.
        assert (out / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """Checkpoint file is cleaned up after completion."""
        out = tmp_path / "output"
        DatasetGenerator(self._cfg(out, total=5, per_shard=5)).run()
        assert not (out / ".checkpoint.json").exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json has expected top-level keys."""
        out = tmp_path / "output"
        DatasetGenerator(self._cfg(out, total=10, per_shard=10)).run()

        stats = json.loads((out / "stats.json").read_text())
        assert stats["total_examples"] == 10
        for key in (
            "classification_distribution",
            "body_count_histogram",
            "joint_type_distribution",
            "dof_statistics",
            "geometric_degeneracy",
            "rigidity",
        ):
            assert key in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json has expected format."""
        out = tmp_path / "output"
        DatasetGenerator(self._cfg(out, total=15, per_shard=10)).run()

        index = json.loads((out / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        assert index["total_shards"] == 2
        assert "shards" in index
        for entry in index["shards"].values():
            assert "start_id" in entry
            assert "count" in entry

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Same seed produces same results."""
        for run_dir in ("run1", "run2"):
            DatasetGenerator(
                self._cfg(tmp_path / run_dir, total=5, per_shard=5)
            ).run()

        first = json.loads((tmp_path / "run1" / "stats.json").read_text())
        second = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert first["total_examples"] == second["total_examples"]
        assert (
            first["classification_distribution"]
            == second["classification_distribution"]
        )
# ---------------------------------------------------------------------------
# CLI integration test
# ---------------------------------------------------------------------------
class TestCLI:
    """Run the script via subprocess."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        """The CLI generates a tiny dataset and writes index/stats files.

        The repository root is derived from this file's location instead of
        the previous hard-coded ``/home/developer`` path, so the test works
        in any checkout, CI runner, or container.
        """
        # tests/datagen/test_dataset.py -> repo root is three dirnames up.
        repo_root = os.path.dirname(
            os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
        )
        out_dir = tmp_path / "cli_out"
        result = subprocess.run(
            [
                sys.executable,
                "scripts/generate_synthetic.py",
                "--num-assemblies",
                "5",
                "--output-dir",
                str(out_dir),
                "--shard-size",
                "5",
                "--num-workers",
                "1",
                "--seed",
                "42",
            ],
            capture_output=True,
            text=True,
            cwd=repo_root,
            timeout=120,
            # Ensure the script can import the project package regardless of
            # how pytest was invoked.
            env={**os.environ, "PYTHONPATH": repo_root},
        )
        assert result.returncode == 0, (
            f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )
        assert (out_dir / "index.json").exists()
        assert (out_dir / "stats.json").exists()