- Add solver/datagen/dataset.py with DatasetConfig, DatasetGenerator, ShardSpec/ShardResult dataclasses, parallel shard generation via ProcessPoolExecutor, checkpoint/resume support, index and stats output
- Add scripts/generate_synthetic.py CLI entry point with Hydra-first and argparse fallback modes
- Add minimal YAML parser (parse_simple_yaml) for config loading without PyYAML dependency
- Add progress display with tqdm fallback to print-based ETA
- Update configs/dataset/synthetic.yaml with shard_size, checkpoint_every
- Update solver/datagen/__init__.py with DatasetConfig, DatasetGenerator exports
- Add tests/datagen/test_dataset.py with 28 tests covering config, YAML parsing, seed derivation, end-to-end generation, resume, stats/index structure, determinism, and CLI integration

Closes #10
116 lines · 3.7 KiB · Python
#!/usr/bin/env python3
"""Generate synthetic assembly dataset for kindred-solver training.

Usage (argparse fallback — always available)::

    python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4

Usage (Hydra — when hydra-core is installed)::

    python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4
"""

from __future__ import annotations

|
def _try_hydra_main() -> bool:
|
|
"""Attempt to run via Hydra. Returns *True* if Hydra handled it."""
|
|
try:
|
|
import hydra # type: ignore[import-untyped]
|
|
from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped]
|
|
except ImportError:
|
|
return False
|
|
|
|
@hydra.main(
|
|
config_path="../configs/dataset",
|
|
config_name="synthetic",
|
|
version_base=None,
|
|
)
|
|
def _run(cfg: DictConfig) -> None: # type: ignore[type-arg]
|
|
from solver.datagen.dataset import DatasetConfig, DatasetGenerator
|
|
|
|
config_dict = OmegaConf.to_container(cfg, resolve=True)
|
|
config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type]
|
|
DatasetGenerator(config).run()
|
|
|
|
_run() # type: ignore[no-untyped-call]
|
|
return True
|
|
|
|
|
|
def _argparse_main() -> None:
    """Fallback CLI using argparse (used when hydra-core is not installed).

    Precedence: CLI flags override the optional YAML config file, but only
    for flags the user explicitly provided (i.e. whose value is not None).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Generate synthetic assembly dataset",
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to YAML config file (optional)",
    )
    # Scalar options share a shape: default=None distinguishes "explicitly
    # set on the CLI" from "absent", which drives the override merge below.
    scalar_flags = (
        ("--num-assemblies", int, "Number of assemblies"),
        ("--output-dir", str, "Output directory"),
        ("--shard-size", int, "Assemblies per shard"),
        ("--body-count-min", int, "Min body count"),
        ("--body-count-max", int, "Max body count"),
        ("--grounded-ratio", float, "Grounded ratio"),
        ("--seed", int, "Random seed"),
        ("--num-workers", int, "Parallel workers"),
        ("--checkpoint-every", int, "Checkpoint interval (shards)"),
    )
    for flag, flag_type, help_text in scalar_flags:
        parser.add_argument(flag, type=flag_type, default=None, help=help_text)
    parser.add_argument(
        "--no-resume",
        action="store_true",
        help="Do not resume from existing checkpoints",
    )

    args = parser.parse_args()

    # Imported lazily so `--help` works without the package on sys.path.
    from solver.datagen.dataset import (
        DatasetConfig,
        DatasetGenerator,
        parse_simple_yaml,
    )

    config_dict: dict[str, object] = {}
    if args.config:
        config_dict = parse_simple_yaml(args.config)  # type: ignore[assignment]

    # CLI args override config file (only when explicitly provided).
    override_keys = (
        "num_assemblies",
        "output_dir",
        "shard_size",
        "body_count_min",
        "body_count_max",
        "grounded_ratio",
        "seed",
        "num_workers",
        "checkpoint_every",
    )
    config_dict.update(
        {key: getattr(args, key) for key in override_keys if getattr(args, key) is not None}
    )

    if args.no_resume:
        config_dict["resume"] = False

    config = DatasetConfig.from_dict(config_dict)  # type: ignore[arg-type]
    DatasetGenerator(config).run()
|
|
|
|
|
|
def main() -> None:
    """Entry point: prefer Hydra when installed, otherwise use argparse."""
    if _try_hydra_main():
        return
    _argparse_main()
|
|
|
|
|
|
# Script entry point — no work happens at import time.
if __name__ == "__main__":
    main()
|