Files
solver/scripts/generate_synthetic.py
forbes-0023 f29060491e
Some checks failed
CI / lint (push) Has been cancelled
CI / type-check (push) Has been cancelled
CI / test (push) Has been cancelled
feat(datagen): add dataset generation CLI with sharding and checkpointing
- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator,
  ShardSpec/ShardResult dataclasses, parallel shard generation via
  ProcessPoolExecutor, checkpoint/resume support, index and stats output
- Add scripts/generate_synthetic.py CLI entry point with Hydra-first
  and argparse fallback modes
- Add minimal YAML parser (parse_simple_yaml) for config loading
  without PyYAML dependency
- Add progress display with tqdm fallback to print-based ETA
- Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every
- Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator
  exports
- Add tests/datagen/test_dataset.py with 28 tests covering config,
  YAML parsing, seed derivation, end-to-end generation, resume,
  stats/index structure, determinism, and CLI integration

Closes #10
2026-02-03 08:44:31 -06:00

116 lines
3.7 KiB
Python

#!/usr/bin/env python3
"""Generate synthetic assembly dataset for kindred-solver training.

Usage (argparse fallback — always available)::

    python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4

Usage (Hydra — when hydra-core is installed)::

    python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
"""
from __future__ import annotations
# Option names defined only by the argparse fallback CLI. Hydra's own
# command-line parser does not define these, so if any of them appears the
# caller is using the argparse-style interface and Hydra must not consume argv.
_ARGPARSE_ONLY_FLAGS = frozenset(
    {
        "--config",
        "--num-assemblies",
        "--output-dir",
        "--shard-size",
        "--body-count-min",
        "--body-count-max",
        "--grounded-ratio",
        "--seed",
        "--num-workers",
        "--checkpoint-every",
        "--no-resume",
    }
)


def _uses_argparse_flags(argv: list[str]) -> bool:
    """Return True if *argv* contains an option only the argparse CLI accepts.

    Both ``--flag value`` and ``--flag=value`` spellings are recognised.
    Hydra-style ``key=value`` overrides never match.
    """
    return any(arg.partition("=")[0] in _ARGPARSE_ONLY_FLAGS for arg in argv)


def _try_hydra_main() -> bool:
    """Attempt to run via Hydra. Returns *True* if Hydra handled it.

    Returns *False* without invoking Hydra when:

    * the command line uses argparse-style flags (``--num-assemblies`` etc.),
      which Hydra's parser would reject, or
    * ``hydra`` / ``omegaconf`` cannot be imported.
    """
    import sys

    # Hydra parses sys.argv itself; hand argparse-style invocations to the
    # fallback CLI instead of letting Hydra fail on unknown options.
    if _uses_argparse_flags(sys.argv[1:]):
        return False

    try:
        import hydra  # type: ignore[import-untyped]
        from omegaconf import DictConfig, OmegaConf  # type: ignore[import-untyped]
    except ImportError:
        return False

    @hydra.main(
        config_path="../configs/dataset",
        config_name="synthetic",
        version_base=None,
    )
    def _run(cfg: DictConfig) -> None:  # type: ignore[type-arg]
        from solver.datagen.dataset import DatasetConfig, DatasetGenerator

        # Convert the OmegaConf tree (with interpolations resolved) into plain
        # Python containers for DatasetConfig.from_dict.
        config_dict = OmegaConf.to_container(cfg, resolve=True)
        config = DatasetConfig.from_dict(config_dict)  # type: ignore[arg-type]
        DatasetGenerator(config).run()

    _run()  # type: ignore[no-untyped-call]
    return True
def _argparse_main() -> None:
"""Fallback CLI using argparse."""
import argparse
parser = argparse.ArgumentParser(
description="Generate synthetic assembly dataset",
)
parser.add_argument(
"--config",
type=str,
default=None,
help="Path to YAML config file (optional)",
)
parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
parser.add_argument("--seed", type=int, default=None, help="Random seed")
parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
parser.add_argument(
"--checkpoint-every",
type=int,
default=None,
help="Checkpoint interval (shards)",
)
parser.add_argument(
"--no-resume",
action="store_true",
help="Do not resume from existing checkpoints",
)
args = parser.parse_args()
from solver.datagen.dataset import (
DatasetConfig,
DatasetGenerator,
parse_simple_yaml,
)
config_dict: dict[str, object] = {}
if args.config:
config_dict = parse_simple_yaml(args.config) # type: ignore[assignment]
# CLI args override config file (only when explicitly provided)
_override_map = {
"num_assemblies": args.num_assemblies,
"output_dir": args.output_dir,
"shard_size": args.shard_size,
"body_count_min": args.body_count_min,
"body_count_max": args.body_count_max,
"grounded_ratio": args.grounded_ratio,
"seed": args.seed,
"num_workers": args.num_workers,
"checkpoint_every": args.checkpoint_every,
}
for key, val in _override_map.items():
if val is not None:
config_dict[key] = val
if args.no_resume:
config_dict["resume"] = False
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
DatasetGenerator(config).run()
def main() -> None:
    """Run the dataset-generation CLI.

    Hydra is given the first chance to handle the invocation. Only when
    ``_try_hydra_main`` reports that it did not is the argparse CLI used.
    """
    handled_by_hydra = _try_hydra_main()
    if handled_by_hydra:
        return
    _argparse_main()
if __name__ == "__main__":
    # Run the CLI only when this file is executed as a script, not on import.
    main()