feat(datagen): add dataset generation CLI with sharding and checkpointing
- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator, ShardSpec/ShardResult dataclasses, parallel shard generation via ProcessPoolExecutor, checkpoint/resume support, and index/stats output
- Add scripts/generate_synthetic.py CLI entry point with Hydra-first and argparse fallback modes
- Add a minimal YAML parser (parse_simple_yaml) for config loading without a PyYAML dependency
- Add a progress display using tqdm, falling back to print-based ETA
- Update configs/dataset/synthetic.yaml with shard_size and checkpoint_every
- Update solver/datagen/__init__.py with DatasetConfig and DatasetGenerator exports
- Add tests/datagen/test_dataset.py with 28 tests covering config, YAML parsing, seed derivation, end-to-end generation, resume, stats/index structure, determinism, and CLI integration

Closes #10
This commit is contained in:
115
scripts/generate_synthetic.py
Normal file
115
scripts/generate_synthetic.py
Normal file
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Generate synthetic assembly dataset for kindred-solver training.
|
||||
|
||||
Usage (argparse fallback — always available)::
|
||||
|
||||
python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4
|
||||
|
||||
Usage (Hydra — when hydra-core is installed)::
|
||||
|
||||
python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def _try_hydra_main() -> bool:
|
||||
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
||||
try:
|
||||
import hydra # type: ignore[import-untyped]
|
||||
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
@hydra.main(
|
||||
config_path="../configs/dataset",
|
||||
config_name="synthetic",
|
||||
version_base=None,
|
||||
)
|
||||
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
||||
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
||||
|
||||
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
||||
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
||||
DatasetGenerator(config).run()
|
||||
|
||||
_run() # type: ignore[no-untyped-call]
|
||||
return True
|
||||
|
||||
|
||||
def _argparse_main() -> None:
    """Fallback CLI using argparse.

    Builds the same effective configuration as the Hydra path: an optional
    YAML config file supplies the base values, and any flag the user passed
    explicitly on the command line overrides the corresponding key.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate synthetic assembly dataset",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to YAML config file (optional)",
    )
    # Every override defaults to None so "not provided" is distinguishable
    # from an explicit value.
    parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies")
    parser.add_argument("--output-dir", type=str, default=None, help="Output directory")
    parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard")
    parser.add_argument("--body-count-min", type=int, default=None, help="Min body count")
    parser.add_argument("--body-count-max", type=int, default=None, help="Max body count")
    parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio")
    parser.add_argument("--seed", type=int, default=None, help="Random seed")
    parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers")
    parser.add_argument(
        "--checkpoint-every",
        type=int,
        default=None,
        help="Checkpoint interval (shards)",
    )
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not resume from existing checkpoints",
    )

    args = parser.parse_args()

    # Imported lazily to keep module import cheap and dependency-free.
    from solver.datagen.dataset import (
        DatasetConfig,
        DatasetGenerator,
        parse_simple_yaml,
    )

    config_dict: dict[str, object] = {}
    if args.config:
        config_dict = parse_simple_yaml(args.config)  # type: ignore[assignment]

    # CLI args override config file (only when explicitly provided).
    overrides = {
        "num_assemblies": args.num_assemblies,
        "output_dir": args.output_dir,
        "shard_size": args.shard_size,
        "body_count_min": args.body_count_min,
        "body_count_max": args.body_count_max,
        "grounded_ratio": args.grounded_ratio,
        "seed": args.seed,
        "num_workers": args.num_workers,
        "checkpoint_every": args.checkpoint_every,
    }
    config_dict.update({key: val for key, val in overrides.items() if val is not None})

    if args.no_resume:
        config_dict["resume"] = False

    config = DatasetConfig.from_dict(config_dict)  # type: ignore[arg-type]
    DatasetGenerator(config).run()
|
||||
|
||||
|
||||
def main() -> None:
    """Entry point: try Hydra first, fall back to argparse."""
    handled_by_hydra = _try_hydra_main()
    if not handled_by_hydra:
        _argparse_main()
|
||||
|
||||
|
||||
# Script entry point: dispatch to main() (Hydra if available, else argparse).
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user