feat(datagen): add dataset generation CLI with sharding and checkpointing
- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator, ShardSpec/ShardResult dataclasses, parallel shard generation via ProcessPoolExecutor, checkpoint/resume support, index and stats output - Add scripts/generate_synthetic.py CLI entry point with Hydra-first and argparse fallback modes - Add minimal YAML parser (parse_simple_yaml) for config loading without PyYAML dependency - Add progress display with tqdm fallback to print-based ETA - Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every - Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator exports - Add tests/datagen/test_dataset.py with 28 tests covering config, YAML parsing, seed derivation, end-to-end generation, resume, stats/index structure, determinism, and CLI integration Closes #10
This commit is contained in:
@@ -2,6 +2,7 @@
|
||||
name: synthetic
|
||||
num_assemblies: 100000
|
||||
output_dir: data/synthetic
|
||||
shard_size: 1000
|
||||
|
||||
complexity_distribution:
|
||||
simple: 0.4 # 2-5 bodies
|
||||
@@ -22,3 +23,4 @@ templates:
|
||||
grounded_ratio: 0.5
|
||||
seed: 42
|
||||
num_workers: 4
|
||||
checkpoint_every: 5
|
||||
|
||||
115
scripts/generate_synthetic.py
Normal file
115
scripts/generate_synthetic.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate synthetic assembly dataset for kindred-solver training.
|
||||
|
||||
Usage (argparse fallback — always available)::
|
||||
|
||||
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
|
||||
|
||||
Usage (Hydra — when hydra-core is installed)::
|
||||
|
||||
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _try_hydra_main() -> bool:
|
||||
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
||||
try:
|
||||
import hydra # type: ignore[import-untyped]
|
||||
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@hydra.main(
|
||||
config_path="../configs/dataset",
|
||||
config_name="synthetic",
|
||||
version_base=None,
|
||||
)
|
||||
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
|
||||
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
_run() # type: ignore[no-untyped-call]
|
||||
return True
|
||||
|
||||
|
||||
def _argparse_main() -> None:
    """Fallback CLI using argparse.

    Config precedence (lowest to highest): ``DatasetConfig`` defaults,
    the optional ``--config`` YAML file, then any CLI flag that was
    explicitly provided (flags default to ``None`` so absence is
    distinguishable from an explicit value).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate synthetic assembly dataset",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to YAML config file (optional)",
    )
    parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
    parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
    parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
    parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=None,
        help="Checkpoint interval (shards)",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not resume from existing checkpoints",
    )

    args = parser.parse_args()

    # Imported after parse_args() so `--help` exits before touching solver.
    from solver.datagen.dataset import (
        DatasetConfig,
        DatasetGenerator,
        parse_simple_yaml,
    )

    config_dict: dict[str, object] = {}
    if args.config:
        config_dict = parse_simple_yaml(args.config)  # type: ignore[assignment]

    # CLI args override config file (only when explicitly provided)
    _override_map = {
        "num_assemblies": args.num_assemblies,
        "output_dir": args.output_dir,
        "shard_size": args.shard_size,
        "body_count_min": args.body_count_min,
        "body_count_max": args.body_count_max,
        "grounded_ratio": args.grounded_ratio,
        "seed": args.seed,
        "num_workers": args.num_workers,
        "checkpoint_every": args.checkpoint_every,
    }
    for key, val in _override_map.items():
        if val is not None:
            config_dict[key] = val

    if args.no_resume:
        config_dict["resume"] = False

    config = DatasetConfig.from_dict(config_dict)  # type: ignore[arg-type]
    DatasetGenerator(config).run()
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: prefer the Hydra CLI, otherwise use argparse."""
    handled = _try_hydra_main()
    if not handled:
        _argparse_main()


if __name__ == "__main__":
    main()
|
||||
@@ -1,6 +1,7 @@
|
||||
"""Data generation utilities for assembly constraint training data."""
|
||||
|
||||
from solver.datagen.analysis import analyze_assembly
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
from solver.datagen.generator import (
|
||||
COMPLEXITY_RANGES,
|
||||
AxisStrategy,
|
||||
@@ -22,6 +23,8 @@ __all__ = [
|
||||
"AssemblyLabels",
|
||||
"AxisStrategy",
|
||||
"ConstraintAnalysis",
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"JacobianVerifier",
|
||||
"Joint",
|
||||
"JointType",
|
||||
|
||||
624
solver/datagen/dataset.py
Normal file
624
solver/datagen/dataset.py
Normal file
@@ -0,0 +1,624 @@
|
||||
"""Dataset generation orchestrator with sharding, checkpointing, and statistics.
|
||||
|
||||
Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator`
|
||||
for parallel generation of synthetic assembly training data.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import numpy as np
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Any
|
||||
|
||||
__all__ = [
|
||||
"DatasetConfig",
|
||||
"DatasetGenerator",
|
||||
"parse_simple_yaml",
|
||||
]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Configuration
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class DatasetConfig:
    """Configuration for synthetic dataset generation."""

    name: str = "synthetic"
    num_assemblies: int = 100_000
    output_dir: str = "data/synthetic"
    shard_size: int = 1000
    complexity_distribution: dict[str, float] = field(
        default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2}
    )
    body_count_min: int = 2
    body_count_max: int = 50
    templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"])
    grounded_ratio: float = 0.5
    seed: int = 42
    num_workers: int = 4
    checkpoint_every: int = 5
    resume: bool = True

    @classmethod
    def from_dict(cls, d: dict[str, Any]) -> DatasetConfig:
        """Construct from a parsed config dict (e.g. YAML or OmegaConf).

        Handles both flat keys (``body_count_min``) and nested forms
        (``body_count: {min: 2, max: 50}``). Unknown keys are ignored;
        missing keys keep their dataclass defaults.
        """
        # Scalars are copied through verbatim when present.
        scalar_keys = (
            "name",
            "num_assemblies",
            "output_dir",
            "shard_size",
            "grounded_ratio",
            "seed",
            "num_workers",
            "checkpoint_every",
            "resume",
        )
        kw: dict[str, Any] = {k: d[k] for k in scalar_keys if k in d}

        # Body count: nested dict form wins; otherwise fall back to flat keys.
        body_count = d.get("body_count")
        if isinstance(body_count, dict):
            if "min" in body_count:
                kw["body_count_min"] = int(body_count["min"])
            if "max" in body_count:
                kw["body_count_max"] = int(body_count["max"])
        else:
            for flat_key in ("body_count_min", "body_count_max"):
                if flat_key in d:
                    kw[flat_key] = int(d[flat_key])

        # Normalize key/value types for the distribution and template list.
        dist = d.get("complexity_distribution")
        if isinstance(dist, dict):
            kw["complexity_distribution"] = {str(k): float(v) for k, v in dist.items()}

        templates = d.get("templates")
        if isinstance(templates, list):
            kw["templates"] = [str(t) for t in templates]

        return cls(**kw)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard specification / result
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
class ShardSpec:
    """Specification for generating a single shard.

    Instances are submitted to a ``ProcessPoolExecutor``, so all fields
    must be picklable.
    """

    shard_id: int  # zero-based shard index
    start_example_id: int  # global example id of this shard's first example
    count: int  # number of examples to generate in this shard
    seed: int  # per-shard RNG seed, derived from the global seed
    complexity_distribution: dict[str, float]  # tier name -> sampling weight
    body_count_min: int  # inclusive lower bound on bodies per assembly
    body_count_max: int  # inclusive upper bound on bodies per assembly
    grounded_ratio: float  # forwarded to the generator's grounded_ratio
|
||||
|
||||
|
||||
@dataclass
class ShardResult:
    """Result returned from a shard worker."""

    shard_id: int  # matches ShardSpec.shard_id
    num_examples: int  # examples actually written (failed examples are skipped)
    file_path: str  # path of the shard file written to disk
    generation_time_s: float  # wall-clock time spent generating the shard
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _derive_shard_seed(global_seed: int, shard_id: int) -> int:
|
||||
"""Derive a deterministic per-shard seed from the global seed."""
|
||||
h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest()
|
||||
return int(h[:8], 16)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Progress display
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PrintProgress:
|
||||
"""Fallback progress display when tqdm is unavailable."""
|
||||
|
||||
def __init__(self, total: int) -> None:
|
||||
self.total = total
|
||||
self.current = 0
|
||||
self.start_time = time.monotonic()
|
||||
|
||||
def update(self, n: int = 1) -> None:
|
||||
self.current += n
|
||||
elapsed = time.monotonic() - self.start_time
|
||||
rate = self.current / elapsed if elapsed > 0 else 0.0
|
||||
eta = (self.total - self.current) / rate if rate > 0 else 0.0
|
||||
pct = 100.0 * self.current / self.total
|
||||
sys.stdout.write(
|
||||
f"\r[{pct:5.1f}%] {self.current}/{self.total} shards"
|
||||
f" | {rate:.1f} shards/s | ETA: {eta:.0f}s"
|
||||
)
|
||||
sys.stdout.flush()
|
||||
|
||||
def close(self) -> None:
|
||||
sys.stdout.write("\n")
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def _make_progress(total: int) -> _PrintProgress:
    """Create a progress tracker (tqdm if available, else print-based)."""
    try:
        from tqdm import tqdm  # type: ignore[import-untyped]
    except ImportError:
        return _PrintProgress(total)
    return tqdm(total=total, desc="Generating shards", unit="shard")  # type: ignore[no-any-return,return-value]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard I/O
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _save_shard(
|
||||
shard_id: int,
|
||||
examples: list[dict[str, Any]],
|
||||
shards_dir: Path,
|
||||
) -> Path:
|
||||
"""Save a shard to disk (.pt if torch available, else .json)."""
|
||||
shards_dir.mkdir(parents=True, exist_ok=True)
|
||||
try:
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
path = shards_dir / f"shard_{shard_id:05d}.pt"
|
||||
torch.save(examples, path)
|
||||
except ImportError:
|
||||
path = shards_dir / f"shard_{shard_id:05d}.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(examples, f)
|
||||
return path
|
||||
|
||||
|
||||
def _load_shard(path: Path) -> list[dict[str, Any]]:
|
||||
"""Load a shard from disk (.pt or .json)."""
|
||||
if path.suffix == ".pt":
|
||||
import torch # type: ignore[import-untyped]
|
||||
|
||||
result: list[dict[str, Any]] = torch.load(path, weights_only=False)
|
||||
return result
|
||||
with open(path) as f:
|
||||
result = json.load(f)
|
||||
return result
|
||||
|
||||
|
||||
def _shard_format() -> str:
|
||||
"""Return the shard file extension based on available libraries."""
|
||||
try:
|
||||
import torch # type: ignore[import-untyped] # noqa: F401
|
||||
|
||||
return ".pt"
|
||||
except ImportError:
|
||||
return ".json"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard worker (module-level for pickling)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult:
    """Generate a single shard — top-level function for ProcessPoolExecutor.

    Generates up to ``spec.count`` examples (failed examples are logged
    and skipped, so the shard may hold fewer), writes the shard to
    ``<output_dir>/shards``, and returns a :class:`ShardResult`.
    """
    from solver.datagen.generator import SyntheticAssemblyGenerator

    t0 = time.monotonic()
    gen = SyntheticAssemblyGenerator(seed=spec.seed)
    # Separate RNG (seed + 1) used only for tier sampling — presumably to
    # keep it independent of the generator's own random stream.
    rng = np.random.default_rng(spec.seed + 1)

    # Normalize tier weights into a probability distribution.
    tiers = list(spec.complexity_distribution.keys())
    probs_list = [spec.complexity_distribution[t] for t in tiers]
    total = sum(probs_list)
    probs = [p / total for p in probs_list]

    examples: list[dict[str, Any]] = []
    for i in range(spec.count):
        tier_idx = int(rng.choice(len(tiers), p=probs))
        tier = tiers[tier_idx]
        try:
            batch = gen.generate_training_batch(
                batch_size=1,
                complexity_tier=tier,  # type: ignore[arg-type]
                grounded_ratio=spec.grounded_ratio,
            )
            ex = batch[0]
            # Tag each example with a globally-unique id and its tier.
            ex["example_id"] = spec.start_example_id + i
            ex["complexity_tier"] = tier
            examples.append(ex)
        except Exception:
            # Best effort: one bad example must not kill the whole shard.
            logger.warning(
                "Shard %d, example %d failed — skipping",
                spec.shard_id,
                i,
                exc_info=True,
            )

    shards_dir = Path(output_dir) / "shards"
    path = _save_shard(spec.shard_id, examples, shards_dir)

    elapsed = time.monotonic() - t0
    return ShardResult(
        shard_id=spec.shard_id,
        num_examples=len(examples),
        file_path=str(path),
        generation_time_s=elapsed,
    )
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _parse_scalar(value: str) -> int | float | bool | str:
|
||||
"""Parse a YAML scalar value."""
|
||||
# Strip inline comments (space + #)
|
||||
if " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
elif " #" in value:
|
||||
value = value[: value.index(" #")].strip()
|
||||
v = value.strip()
|
||||
if v.lower() in ("true", "yes"):
|
||||
return True
|
||||
if v.lower() in ("false", "no"):
|
||||
return False
|
||||
try:
|
||||
return int(v)
|
||||
except ValueError:
|
||||
pass
|
||||
try:
|
||||
return float(v)
|
||||
except ValueError:
|
||||
pass
|
||||
return v.strip("'\"")
|
||||
|
||||
|
||||
def parse_simple_yaml(path: str) -> dict[str, Any]:
    """Parse a simple YAML file (flat scalars, one-level dicts, lists).

    This is **not** a full YAML parser. It handles the structure of
    ``configs/dataset/synthetic.yaml``:

    - top-level ``key: value`` scalars,
    - one level of nesting: a bare ``key:`` header followed by indented
      ``sub: value`` pairs or ``- item`` list entries.

    Deeper nesting, multi-line values, and flow-style collections are
    not supported.
    """
    result: dict[str, Any] = {}
    # Most recent top-level key whose children are being collected;
    # None while reading flat top-level scalars.
    current_key: str | None = None

    with open(path) as f:
        for raw_line in f:
            line = raw_line.rstrip()

            # Skip blank lines and full-line comments
            if not line or line.lstrip().startswith("#"):
                continue

            indent = len(line) - len(line.lstrip())

            # Top-level "key: value" (scalar) or bare "key:" (section header).
            if indent == 0 and ":" in line:
                key, _, value = line.partition(":")
                key = key.strip()
                value = value.strip()
                if value:
                    result[key] = _parse_scalar(value)
                    current_key = None
                else:
                    # Placeholder dict; converted to a list below if the
                    # first child turns out to be a "- item" entry.
                    current_key = key
                    result[key] = {}
                continue

            # Indented "- item": list entry under the current section.
            if indent > 0 and line.lstrip().startswith("- "):
                item = line.lstrip()[2:].strip()
                if current_key is not None:
                    if isinstance(result.get(current_key), dict) and not result[current_key]:
                        result[current_key] = []
                    if isinstance(result.get(current_key), list):
                        result[current_key].append(_parse_scalar(item))
                continue

            # Indented "sub: value": dict entry under the current section.
            if indent > 0 and ":" in line and current_key is not None:
                k, _, v = line.partition(":")
                k = k.strip()
                v = v.strip()
                if v:
                    # Strip inline comments
                    if " #" in v:
                        v = v[: v.index(" #")].strip()
                    if not isinstance(result.get(current_key), dict):
                        result[current_key] = {}
                    result[current_key][k] = _parse_scalar(v)
                continue

    return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Dataset generator orchestrator
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DatasetGenerator:
    """Orchestrates parallel dataset generation with sharding and checkpointing.

    Lifecycle (see :meth:`run`): plan shards, skip any already on disk
    (resume), generate the rest in a process pool, then write
    ``index.json`` / ``stats.json`` and print a summary.
    """

    def __init__(self, config: DatasetConfig) -> None:
        self.config = config
        # Output layout: <output_dir>/shards/shard_XXXXX.{pt,json},
        # index.json, stats.json, and a transient .checkpoint.json.
        self.output_path = Path(config.output_dir)
        self.shards_dir = self.output_path / "shards"
        self.checkpoint_file = self.output_path / ".checkpoint.json"
        self.index_file = self.output_path / "index.json"
        self.stats_file = self.output_path / "stats.json"

    # -- public API --

    def run(self) -> None:
        """Generate the full dataset.

        Safe to re-run: with ``config.resume`` enabled, shard files
        already on disk are skipped. The checkpoint file is removed on
        successful completion.
        """
        self.output_path.mkdir(parents=True, exist_ok=True)
        self.shards_dir.mkdir(parents=True, exist_ok=True)

        shards = self._plan_shards()
        total_shards = len(shards)

        # Resume: find already-completed shards
        completed: set[int] = set()
        if self.config.resume:
            completed = self._find_completed_shards()

        pending = [s for s in shards if s.shard_id not in completed]

        if not pending:
            logger.info("All %d shards already complete.", total_shards)
        else:
            logger.info(
                "Generating %d shards (%d already complete).",
                len(pending),
                len(completed),
            )
            progress = _make_progress(len(pending))
            workers = max(1, self.config.num_workers)
            checkpoint_counter = 0

            with ProcessPoolExecutor(max_workers=workers) as pool:
                # Map each future back to its shard id for error reporting.
                futures = {
                    pool.submit(
                        _generate_shard_worker,
                        spec,
                        str(self.output_path),
                    ): spec.shard_id
                    for spec in pending
                }
                for future in as_completed(futures):
                    shard_id = futures[future]
                    try:
                        result = future.result()
                        completed.add(result.shard_id)
                        logger.debug(
                            "Shard %d: %d examples in %.1fs",
                            result.shard_id,
                            result.num_examples,
                            result.generation_time_s,
                        )
                    except Exception:
                        # A failed shard is logged and left for a later
                        # resume run; the remaining shards keep going.
                        logger.error("Shard %d failed", shard_id, exc_info=True)
                    progress.update(1)
                    checkpoint_counter += 1
                    # Persist progress every `checkpoint_every` finished shards.
                    if checkpoint_counter >= self.config.checkpoint_every:
                        self._update_checkpoint(completed, total_shards)
                        checkpoint_counter = 0

            progress.close()

        # Finalize
        self._build_index()
        stats = self._compute_statistics()
        self._write_statistics(stats)
        self._print_summary(stats)

        # Remove checkpoint (generation complete)
        if self.checkpoint_file.exists():
            self.checkpoint_file.unlink()

    # -- internal helpers --

    def _plan_shards(self) -> list[ShardSpec]:
        """Divide num_assemblies into shards.

        The final shard may hold fewer than ``shard_size`` examples.
        Each shard gets its own deterministic seed via _derive_shard_seed.
        """
        n = self.config.num_assemblies
        size = self.config.shard_size
        num_shards = math.ceil(n / size)
        shards: list[ShardSpec] = []
        for i in range(num_shards):
            start = i * size
            count = min(size, n - start)
            shards.append(
                ShardSpec(
                    shard_id=i,
                    start_example_id=start,
                    count=count,
                    seed=_derive_shard_seed(self.config.seed, i),
                    complexity_distribution=dict(self.config.complexity_distribution),
                    body_count_min=self.config.body_count_min,
                    body_count_max=self.config.body_count_max,
                    grounded_ratio=self.config.grounded_ratio,
                )
            )
        return shards

    def _find_completed_shards(self) -> set[int]:
        """Scan shards directory for existing shard files.

        The shard files themselves (not the checkpoint file) are the
        source of truth for resume: any non-empty ``shard_<id>.*`` counts
        as complete.
        """
        completed: set[int] = set()
        if not self.shards_dir.exists():
            return completed

        for p in self.shards_dir.iterdir():
            if p.stem.startswith("shard_"):
                try:
                    shard_id = int(p.stem.split("_")[1])
                    # Verify file is non-empty
                    if p.stat().st_size > 0:
                        completed.add(shard_id)
                except (ValueError, IndexError):
                    # Not a shard file after all — ignore it.
                    pass
        return completed

    def _update_checkpoint(self, completed: set[int], total_shards: int) -> None:
        """Write checkpoint file."""
        data = {
            "completed_shards": sorted(completed),
            "total_shards": total_shards,
            "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()),
        }
        with open(self.checkpoint_file, "w") as f:
            json.dump(data, f)

    def _build_index(self) -> None:
        """Build index.json mapping shard files to assembly ID ranges."""
        shards_info: dict[str, dict[str, int]] = {}
        total_assemblies = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            try:
                shard_id = int(p.stem.split("_")[1])
            except (ValueError, IndexError):
                continue
            # Load each shard to get the actual example count (shards may
            # hold fewer than shard_size when examples failed generation).
            examples = _load_shard(p)
            count = len(examples)
            start_id = shard_id * self.config.shard_size
            shards_info[p.name] = {"start_id": start_id, "count": count}
            total_assemblies += count

        fmt = _shard_format().lstrip(".")
        index = {
            "format_version": 1,
            "total_assemblies": total_assemblies,
            "total_shards": len(shards_info),
            "shard_format": fmt,
            "shards": shards_info,
        }
        with open(self.index_file, "w") as f:
            json.dump(index, f, indent=2)

    def _compute_statistics(self) -> dict[str, Any]:
        """Aggregate statistics across all shards.

        Reads every shard back from disk and tallies classification,
        body-count, joint-type, DOF, degeneracy, and rigidity stats.
        Missing fields fall back to "unknown"/0 defaults.
        """
        classification_counts: dict[str, int] = {}
        body_count_hist: dict[int, int] = {}
        joint_type_counts: dict[str, int] = {}
        dof_values: list[int] = []
        degeneracy_values: list[int] = []
        rigid_count = 0
        minimally_rigid_count = 0
        total = 0

        for p in sorted(self.shards_dir.iterdir()):
            if not p.stem.startswith("shard_"):
                continue
            examples = _load_shard(p)
            for ex in examples:
                total += 1
                cls = str(ex.get("assembly_classification", "unknown"))
                classification_counts[cls] = classification_counts.get(cls, 0) + 1
                nb = int(ex.get("n_bodies", 0))
                body_count_hist[nb] = body_count_hist.get(nb, 0) + 1
                for j in ex.get("joints", []):
                    jt = str(j.get("type", "unknown"))
                    joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1
                dof_values.append(int(ex.get("internal_dof", 0)))
                degeneracy_values.append(int(ex.get("geometric_degeneracies", 0)))
                if ex.get("is_rigid"):
                    rigid_count += 1
                if ex.get("is_minimally_rigid"):
                    minimally_rigid_count += 1

        # Empty-dataset guard: fall back to a single zero so the numpy
        # reductions below never see an empty array.
        dof_arr = np.array(dof_values) if dof_values else np.zeros(1)
        deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1)

        return {
            "total_examples": total,
            "classification_distribution": dict(sorted(classification_counts.items())),
            "body_count_histogram": dict(sorted(body_count_hist.items())),
            "joint_type_distribution": dict(sorted(joint_type_counts.items())),
            "dof_statistics": {
                "mean": float(dof_arr.mean()),
                "std": float(dof_arr.std()),
                "min": int(dof_arr.min()),
                "max": int(dof_arr.max()),
                "median": float(np.median(dof_arr)),
            },
            "geometric_degeneracy": {
                "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)),
                "fraction_with_degeneracy": float(np.mean(deg_arr > 0)),
                "mean_degeneracies": float(deg_arr.mean()),
            },
            "rigidity": {
                "rigid_count": rigid_count,
                "rigid_fraction": (rigid_count / total if total > 0 else 0.0),
                "minimally_rigid_count": minimally_rigid_count,
                "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0),
            },
        }

    def _write_statistics(self, stats: dict[str, Any]) -> None:
        """Write stats.json."""
        with open(self.stats_file, "w") as f:
            json.dump(stats, f, indent=2)

    def _print_summary(self, stats: dict[str, Any]) -> None:
        """Print a human-readable summary to stdout."""
        print("\n=== Dataset Generation Summary ===")
        print(f"Total examples: {stats['total_examples']}")
        print(f"Output directory: {self.output_path}")
        print()
        print("Classification distribution:")
        for cls, count in stats["classification_distribution"].items():
            # max(..., 1) avoids division by zero on an empty dataset.
            frac = count / max(stats["total_examples"], 1) * 100
            print(f"  {cls}: {count} ({frac:.1f}%)")
        print()
        print("Joint type distribution:")
        for jt, count in stats["joint_type_distribution"].items():
            print(f"  {jt}: {count}")
        print()
        dof = stats["dof_statistics"]
        print(
            f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]"
        )
        rig = stats["rigidity"]
        print(
            f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} "
            f"({rig['rigid_fraction'] * 100:.1f}%) rigid, "
            f"{rig['minimally_rigid_count']} minimally rigid"
        )
        deg = stats["geometric_degeneracy"]
        print(
            f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies "
            f"({deg['fraction_with_degeneracy'] * 100:.1f}%)"
        )
|
||||
337
tests/datagen/test_dataset.py
Normal file
337
tests/datagen/test_dataset.py
Normal file
@@ -0,0 +1,337 @@
|
||||
"""Tests for solver.datagen.dataset — dataset generation orchestration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from solver.datagen.dataset import (
|
||||
DatasetConfig,
|
||||
DatasetGenerator,
|
||||
_derive_shard_seed,
|
||||
_parse_scalar,
|
||||
parse_simple_yaml,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetConfig
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetConfig:
    """DatasetConfig construction and defaults."""

    def test_defaults(self) -> None:
        config = DatasetConfig()
        assert config.num_assemblies == 100_000
        assert config.seed == 42
        assert config.shard_size == 1000
        assert config.num_workers == 4

    def test_from_dict_flat(self) -> None:
        raw: dict[str, Any] = {"num_assemblies": 500, "seed": 123}
        config = DatasetConfig.from_dict(raw)
        assert config.num_assemblies == 500
        assert config.seed == 123

    def test_from_dict_nested_body_count(self) -> None:
        raw: dict[str, Any] = {"body_count": {"min": 3, "max": 20}}
        config = DatasetConfig.from_dict(raw)
        assert (config.body_count_min, config.body_count_max) == (3, 20)

    def test_from_dict_flat_body_count(self) -> None:
        raw: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30}
        config = DatasetConfig.from_dict(raw)
        assert (config.body_count_min, config.body_count_max) == (5, 30)

    def test_from_dict_complexity_distribution(self) -> None:
        raw: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}}
        config = DatasetConfig.from_dict(raw)
        assert config.complexity_distribution == {"simple": 0.6, "complex": 0.4}

    def test_from_dict_templates(self) -> None:
        config = DatasetConfig.from_dict({"templates": ["chain", "tree"]})
        assert config.templates == ["chain", "tree"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Minimal YAML parser
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestParseScalar:
    """_parse_scalar handles different value types."""

    def test_int(self) -> None:
        parsed = _parse_scalar("42")
        assert parsed == 42

    def test_float(self) -> None:
        parsed = _parse_scalar("3.14")
        assert parsed == 3.14

    def test_bool_true(self) -> None:
        parsed = _parse_scalar("true")
        assert parsed is True

    def test_bool_false(self) -> None:
        parsed = _parse_scalar("false")
        assert parsed is False

    def test_string(self) -> None:
        parsed = _parse_scalar("hello")
        assert parsed == "hello"

    def test_inline_comment(self) -> None:
        # The comment after the value must be stripped before parsing.
        parsed = _parse_scalar("0.4 # some comment")
        assert parsed == 0.4
|
||||
|
||||
|
||||
class TestParseSimpleYaml:
    """parse_simple_yaml handles the synthetic.yaml format."""

    def test_flat_scalars(self, tmp_path: Path) -> None:
        cfg_file = tmp_path / "test.yaml"
        cfg_file.write_text("name: test\nnum: 42\nratio: 0.5\n")
        parsed = parse_simple_yaml(str(cfg_file))
        assert parsed["name"] == "test"
        assert parsed["num"] == 42
        assert parsed["ratio"] == 0.5

    def test_nested_dict(self, tmp_path: Path) -> None:
        cfg_file = tmp_path / "test.yaml"
        cfg_file.write_text("body_count:\n  min: 2\n  max: 50\n")
        parsed = parse_simple_yaml(str(cfg_file))
        assert parsed["body_count"] == {"min": 2, "max": 50}

    def test_list(self, tmp_path: Path) -> None:
        cfg_file = tmp_path / "test.yaml"
        cfg_file.write_text("templates:\n  - chain\n  - tree\n  - loop\n")
        parsed = parse_simple_yaml(str(cfg_file))
        assert parsed["templates"] == ["chain", "tree", "loop"]

    def test_inline_comments(self, tmp_path: Path) -> None:
        cfg_file = tmp_path / "test.yaml"
        cfg_file.write_text("dist:\n  simple: 0.4  # comment\n")
        parsed = parse_simple_yaml(str(cfg_file))
        assert parsed["dist"]["simple"] == 0.4

    def test_synthetic_yaml(self) -> None:
        """Parse the actual project config.

        NOTE(review): depends on the working directory being the repo
        root — runs only under the project's test setup.
        """
        parsed = parse_simple_yaml("configs/dataset/synthetic.yaml")
        assert parsed["name"] == "synthetic"
        assert parsed["num_assemblies"] == 100000
        assert isinstance(parsed["complexity_distribution"], dict)
        assert isinstance(parsed["templates"], list)
        assert parsed["shard_size"] == 1000
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shard seed derivation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestShardSeedDerivation:
    """_derive_shard_seed is deterministic and unique per shard."""

    def test_deterministic(self) -> None:
        # Same (global_seed, shard_index) pair must always map to the same seed.
        assert _derive_shard_seed(42, 0) == _derive_shard_seed(42, 0)

    def test_different_shards(self) -> None:
        # Distinct shard indices under one global seed get distinct seeds.
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(42, 1)

    def test_different_global_seeds(self) -> None:
        # Distinct global seeds diverge even for the same shard index.
        assert _derive_shard_seed(42, 0) != _derive_shard_seed(99, 0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# DatasetGenerator — small end-to-end tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDatasetGenerator:
    """End-to-end tests with small datasets."""

    @staticmethod
    def _config(out_dir: Path, *, num_assemblies: int, shard_size: int) -> DatasetConfig:
        # Shared fixture builder: every test runs single-worker with seed=42
        # so output is deterministic and fast.
        return DatasetConfig(
            num_assemblies=num_assemblies,
            output_dir=str(out_dir),
            shard_size=shard_size,
            seed=42,
            num_workers=1,
        )

    def test_small_generation(self, tmp_path: Path) -> None:
        """Generate 10 examples in a single shard."""
        out = tmp_path / "output"
        DatasetGenerator(self._config(out, num_assemblies=10, shard_size=10)).run()

        shards_dir = out / "shards"
        assert shards_dir.exists()
        assert len(sorted(shards_dir.glob("shard_*"))) == 1

        index_file = out / "index.json"
        assert index_file.exists()
        assert json.loads(index_file.read_text())["total_assemblies"] == 10

        assert (out / "stats.json").exists()

    def test_multi_shard(self, tmp_path: Path) -> None:
        """Generate 20 examples across 2 shards."""
        out = tmp_path / "output"
        DatasetGenerator(self._config(out, num_assemblies=20, shard_size=10)).run()
        assert len(sorted((out / "shards").glob("shard_*"))) == 2

    def test_resume_skips_completed(self, tmp_path: Path) -> None:
        """Resume skips already-completed shards."""
        out = tmp_path / "output"
        cfg = self._config(out, num_assemblies=20, shard_size=10)
        DatasetGenerator(cfg).run()

        # Snapshot modification times of the generated shards.
        shards_dir = out / "shards"
        before = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")}

        # Simulate an interrupted run: stats missing, shards intact.
        (out / "stats.json").unlink()
        DatasetGenerator(cfg).run()

        # Completed shards must not have been rewritten ...
        for p in shards_dir.glob("shard_*"):
            assert p.stat().st_mtime == before[p.name]
        # ... while the stats file is rebuilt.
        assert (out / "stats.json").exists()

    def test_checkpoint_removed(self, tmp_path: Path) -> None:
        """Checkpoint file is cleaned up after completion."""
        out = tmp_path / "output"
        DatasetGenerator(self._config(out, num_assemblies=5, shard_size=5)).run()
        assert not (out / ".checkpoint.json").exists()

    def test_stats_structure(self, tmp_path: Path) -> None:
        """stats.json has expected top-level keys."""
        out = tmp_path / "output"
        DatasetGenerator(self._config(out, num_assemblies=10, shard_size=10)).run()

        stats = json.loads((out / "stats.json").read_text())
        assert stats["total_examples"] == 10
        for key in (
            "classification_distribution",
            "body_count_histogram",
            "joint_type_distribution",
            "dof_statistics",
            "geometric_degeneracy",
            "rigidity",
        ):
            assert key in stats

    def test_index_structure(self, tmp_path: Path) -> None:
        """index.json has expected format."""
        out = tmp_path / "output"
        DatasetGenerator(self._config(out, num_assemblies=15, shard_size=10)).run()

        index = json.loads((out / "index.json").read_text())
        assert index["format_version"] == 1
        assert index["total_assemblies"] == 15
        assert index["total_shards"] == 2
        assert "shards" in index
        for info in index["shards"].values():
            assert "start_id" in info
            assert "count" in info

    def test_deterministic_output(self, tmp_path: Path) -> None:
        """Same seed produces same results."""
        for run_dir in ("run1", "run2"):
            DatasetGenerator(
                self._config(tmp_path / run_dir, num_assemblies=5, shard_size=5)
            ).run()

        s1 = json.loads((tmp_path / "run1" / "stats.json").read_text())
        s2 = json.loads((tmp_path / "run2" / "stats.json").read_text())
        assert s1["total_examples"] == s2["total_examples"]
        assert s1["classification_distribution"] == s2["classification_distribution"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# CLI integration test
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestCLI:
    """Run the generation script via subprocess (argparse fallback mode)."""

    def test_argparse_mode(self, tmp_path: Path) -> None:
        """The CLI generates a small dataset end-to-end and writes outputs.

        Fix: derive the repository root from this test file's location
        instead of the hard-coded absolute path "/home/developer", so the
        test works on any checkout and any machine.
        """
        # tests/datagen/test_dataset.py -> parents[2] is the repo root
        # (TODO confirm if this file is ever moved).
        repo_root = Path(__file__).resolve().parents[2]
        result = subprocess.run(
            [
                sys.executable,
                "scripts/generate_synthetic.py",
                "--num-assemblies",
                "5",
                "--output-dir",
                str(tmp_path / "cli_out"),
                "--shard-size",
                "5",
                "--num-workers",
                "1",
                "--seed",
                "42",
            ],
            capture_output=True,
            text=True,
            cwd=str(repo_root),
            timeout=120,
            # Make the in-repo packages importable by the child process.
            env={**os.environ, "PYTHONPATH": str(repo_root)},
        )
        assert result.returncode == 0, (
            f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}"
        )
        assert (tmp_path / "cli_out" / "index.json").exists()
        assert (tmp_path / "cli_out" / "stats.json").exists()
|
||||
Reference in New Issue
Block a user