diff --git a/configs/dataset/synthetic.yaml b/configs/dataset/synthetic.yaml index 8450ad8..5f8c117 100644 --- a/configs/dataset/synthetic.yaml +++ b/configs/dataset/synthetic.yaml @@ -2,6 +2,7 @@ name: synthetic num_assemblies: 100000 output_dir: data/synthetic +shard_size: 1000 complexity_distribution: simple: 0.4 # 2-5 bodies @@ -22,3 +23,4 @@ templates: grounded_ratio: 0.5 seed: 42 num_workers: 4 +checkpoint_every: 5 diff --git a/scripts/generate_synthetic.py b/scripts/generate_synthetic.py new file mode 100644 index 0000000..9a7f99f --- /dev/null +++ b/scripts/generate_synthetic.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python3 +"""Generate synthetic assembly dataset for kindred-solver training. + +Usage (argparse fallback — always available):: + + python scripts/generate_synthetic.py --num-assemblies 1000 --num-workers 4 + +Usage (Hydra — when hydra-core is installed):: + + python scripts/generate_synthetic.py num_assemblies=1000 num_workers=4 +""" + +from __future__ import annotations + + +def _try_hydra_main() -> bool: + """Attempt to run via Hydra. Returns *True* if Hydra handled it.""" + try: + import hydra # type: ignore[import-untyped] + from omegaconf import DictConfig, OmegaConf # type: ignore[import-untyped] + except ImportError: + return False + + @hydra.main( + config_path="../configs/dataset", + config_name="synthetic", + version_base=None, + ) + def _run(cfg: DictConfig) -> None: # type: ignore[type-arg] + from solver.datagen.dataset import DatasetConfig, DatasetGenerator + + config_dict = OmegaConf.to_container(cfg, resolve=True) + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + _run() # type: ignore[no-untyped-call] + return True + + +def _argparse_main() -> None: + """Fallback CLI using argparse.""" + import argparse + + parser = argparse.ArgumentParser( + description="Generate synthetic assembly dataset", + ) + parser.add_argument( + "--config", + type=str, + default=None, + help="Path to YAML config file (optional)", + ) + parser.add_argument("--num-assemblies", type=int, default=None, help="Number of assemblies") + parser.add_argument("--output-dir", type=str, default=None, help="Output directory") + parser.add_argument("--shard-size", type=int, default=None, help="Assemblies per shard") + parser.add_argument("--body-count-min", type=int, default=None, help="Min body count") + parser.add_argument("--body-count-max", type=int, default=None, help="Max body count") + parser.add_argument("--grounded-ratio", type=float, default=None, help="Grounded ratio") + parser.add_argument("--seed", type=int, default=None, help="Random seed") + parser.add_argument("--num-workers", type=int, default=None, help="Parallel workers") + parser.add_argument( + "--checkpoint-every", + type=int, + default=None, + help="Checkpoint interval (shards)", + ) + parser.add_argument( + "--no-resume", + action="store_true", + help="Do not resume from existing checkpoints", + ) + + args = parser.parse_args() + + from solver.datagen.dataset import ( + DatasetConfig, + DatasetGenerator, + parse_simple_yaml, + ) + + config_dict: dict[str, object] = {} + if args.config: + config_dict = parse_simple_yaml(args.config) # type: ignore[assignment] + + # CLI args override config file (only when explicitly provided) + _override_map = { + "num_assemblies": args.num_assemblies, + "output_dir": args.output_dir, + "shard_size": args.shard_size, + "body_count_min": args.body_count_min, + "body_count_max": args.body_count_max, + "grounded_ratio": args.grounded_ratio, + "seed": args.seed, + "num_workers": args.num_workers, + "checkpoint_every": args.checkpoint_every, + } + for key, val in _override_map.items(): + if val is not None: + config_dict[key] = val + + if args.no_resume: + config_dict["resume"] = False + + config = DatasetConfig.from_dict(config_dict) # type: ignore[arg-type] + DatasetGenerator(config).run() + + +def main() -> None: + """Entry point: try Hydra first, fall back to argparse.""" + if not _try_hydra_main(): + _argparse_main() + + +if __name__ == "__main__": + main() diff --git a/solver/datagen/__init__.py b/solver/datagen/__init__.py index 169907e..7e26e21 100644 --- a/solver/datagen/__init__.py +++ b/solver/datagen/__init__.py @@ -1,6 +1,7 @@ """Data generation utilities for assembly constraint training data.""" from solver.datagen.analysis import analyze_assembly +from solver.datagen.dataset import DatasetConfig, DatasetGenerator from solver.datagen.generator import ( COMPLEXITY_RANGES, AxisStrategy, @@ -22,6 +23,8 @@ __all__ = [ "AssemblyLabels", "AxisStrategy", "ConstraintAnalysis", + "DatasetConfig", + "DatasetGenerator", "JacobianVerifier", "Joint", "JointType", diff --git a/solver/datagen/dataset.py b/solver/datagen/dataset.py new file mode 100644 index 0000000..e5a4ce2 --- /dev/null +++ b/solver/datagen/dataset.py @@ -0,0 +1,624 @@ +"""Dataset generation orchestrator with sharding, checkpointing, and statistics. + +Provides :class:`DatasetConfig` for configuration and :class:`DatasetGenerator` +for parallel generation of synthetic assembly training data. +""" + +from __future__ import annotations + +import hashlib +import json +import logging +import math +import sys +import time +from concurrent.futures import ProcessPoolExecutor, as_completed +from dataclasses import dataclass, field +from pathlib import Path +from typing import TYPE_CHECKING + +import numpy as np + +if TYPE_CHECKING: + from typing import Any + +__all__ = [ + "DatasetConfig", + "DatasetGenerator", + "parse_simple_yaml", +] + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Configuration +# --------------------------------------------------------------------------- + + +@dataclass +class DatasetConfig: + """Configuration for synthetic dataset generation.""" + + name: str = "synthetic" + num_assemblies: int = 100_000 + output_dir: str = "data/synthetic" + shard_size: int = 1000 + complexity_distribution: dict[str, float] = field( + default_factory=lambda: {"simple": 0.4, "medium": 0.4, "complex": 0.2} + ) + body_count_min: int = 2 + body_count_max: int = 50 + templates: list[str] = field(default_factory=lambda: ["chain", "tree", "loop", "star", "mixed"]) + grounded_ratio: float = 0.5 + seed: int = 42 + num_workers: int = 4 + checkpoint_every: int = 5 + resume: bool = True + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> DatasetConfig: + """Construct from a parsed config dict (e.g. YAML or OmegaConf). + + Handles both flat keys (``body_count_min``) and nested forms + (``body_count: {min: 2, max: 50}``). + """ + kw: dict[str, Any] = {} + for key in ( + "name", + "num_assemblies", + "output_dir", + "shard_size", + "grounded_ratio", + "seed", + "num_workers", + "checkpoint_every", + "resume", + ): + if key in d: + kw[key] = d[key] + + # Handle nested body_count dict + if "body_count" in d and isinstance(d["body_count"], dict): + bc = d["body_count"] + if "min" in bc: + kw["body_count_min"] = int(bc["min"]) + if "max" in bc: + kw["body_count_max"] = int(bc["max"]) + else: + if "body_count_min" in d: + kw["body_count_min"] = int(d["body_count_min"]) + if "body_count_max" in d: + kw["body_count_max"] = int(d["body_count_max"]) + + if "complexity_distribution" in d: + cd = d["complexity_distribution"] + if isinstance(cd, dict): + kw["complexity_distribution"] = {str(k): float(v) for k, v in cd.items()} + if "templates" in d and isinstance(d["templates"], list): + kw["templates"] = [str(t) for t in d["templates"]] + + return cls(**kw) + + +# --------------------------------------------------------------------------- +# Shard specification / result +# --------------------------------------------------------------------------- + + +@dataclass +class ShardSpec: + """Specification for generating a single shard.""" + + shard_id: int + start_example_id: int + count: int + seed: int + complexity_distribution: dict[str, float] + body_count_min: int + body_count_max: int + grounded_ratio: float + + +@dataclass +class ShardResult: + """Result returned from a shard worker.""" + + shard_id: int + num_examples: int + file_path: str + generation_time_s: float + + +# --------------------------------------------------------------------------- +# Seed derivation +# --------------------------------------------------------------------------- + + +def _derive_shard_seed(global_seed: int, shard_id: int) -> int: + """Derive a deterministic per-shard seed from the global seed.""" + h = hashlib.sha256(f"{global_seed}:{shard_id}".encode()).hexdigest() + return int(h[:8], 16) + + +# --------------------------------------------------------------------------- +# Progress display +# --------------------------------------------------------------------------- + + +class _PrintProgress: + """Fallback progress display when tqdm is unavailable.""" + + def __init__(self, total: int) -> None: + self.total = total + self.current = 0 + self.start_time = time.monotonic() + + def update(self, n: int = 1) -> None: + self.current += n + elapsed = time.monotonic() - self.start_time + rate = self.current / elapsed if elapsed > 0 else 0.0 + eta = (self.total - self.current) / rate if rate > 0 else 0.0 + pct = 100.0 * self.current / self.total + sys.stdout.write( + f"\r[{pct:5.1f}%] {self.current}/{self.total} shards" + f" | {rate:.1f} shards/s | ETA: {eta:.0f}s" + ) + sys.stdout.flush() + + def close(self) -> None: + sys.stdout.write("\n") + sys.stdout.flush() + + +def _make_progress(total: int) -> _PrintProgress: + """Create a progress tracker (tqdm if available, else print-based).""" + try: + from tqdm import tqdm # type: ignore[import-untyped] + + return tqdm(total=total, desc="Generating shards", unit="shard") # type: ignore[no-any-return,return-value] + except ImportError: + return _PrintProgress(total) + + +# --------------------------------------------------------------------------- +# Shard I/O +# --------------------------------------------------------------------------- + + +def _save_shard( + shard_id: int, + examples: list[dict[str, Any]], + shards_dir: Path, +) -> Path: + """Save a shard to disk (.pt if torch available, else .json).""" + shards_dir.mkdir(parents=True, exist_ok=True) + try: + import torch # type: ignore[import-untyped] + + path = shards_dir / f"shard_{shard_id:05d}.pt" + torch.save(examples, path) + except ImportError: + path = shards_dir / f"shard_{shard_id:05d}.json" + with open(path, "w") as f: + json.dump(examples, f) + return path + + +def _load_shard(path: Path) -> list[dict[str, Any]]: + """Load a shard from disk (.pt or .json).""" + if path.suffix == ".pt": + import torch # type: ignore[import-untyped] + + result: list[dict[str, Any]] = torch.load(path, weights_only=False) + return result + with open(path) as f: + result = json.load(f) + return result + + +def _shard_format() -> str: + """Return the shard file extension based on available libraries.""" + try: + import torch # type: ignore[import-untyped] # noqa: F401 + + return ".pt" + except ImportError: + return ".json" + + +# --------------------------------------------------------------------------- +# Shard worker (module-level for pickling) +# --------------------------------------------------------------------------- + + +def _generate_shard_worker(spec: ShardSpec, output_dir: str) -> ShardResult: + """Generate a single shard — top-level function for ProcessPoolExecutor.""" + from solver.datagen.generator import SyntheticAssemblyGenerator + + t0 = time.monotonic() + gen = SyntheticAssemblyGenerator(seed=spec.seed) + rng = np.random.default_rng(spec.seed + 1) + + tiers = list(spec.complexity_distribution.keys()) + probs_list = [spec.complexity_distribution[t] for t in tiers] + total = sum(probs_list) + probs = [p / total for p in probs_list] + + examples: list[dict[str, Any]] = [] + for i in range(spec.count): + tier_idx = int(rng.choice(len(tiers), p=probs)) + tier = tiers[tier_idx] + try: + batch = gen.generate_training_batch( + batch_size=1, + complexity_tier=tier, # type: ignore[arg-type] + grounded_ratio=spec.grounded_ratio, + ) + ex = batch[0] + ex["example_id"] = spec.start_example_id + i + ex["complexity_tier"] = tier + examples.append(ex) + except Exception: + logger.warning( + "Shard %d, example %d failed — skipping", + spec.shard_id, + i, + exc_info=True, + ) + + shards_dir = Path(output_dir) / "shards" + path = _save_shard(spec.shard_id, examples, shards_dir) + + elapsed = time.monotonic() - t0 + return ShardResult( + shard_id=spec.shard_id, + num_examples=len(examples), + file_path=str(path), + generation_time_s=elapsed, + ) + + +# --------------------------------------------------------------------------- +# Minimal YAML parser +# --------------------------------------------------------------------------- + + +def _parse_scalar(value: str) -> int | float | bool | str: + """Parse a YAML scalar value.""" + # Strip inline comments (space + #) + if " #" in value: + value = value[: value.index(" #")].strip() + elif " #" in value: + value = value[: value.index(" #")].strip() + v = value.strip() + if v.lower() in ("true", "yes"): + return True + if v.lower() in ("false", "no"): + return False + try: + return int(v) + except ValueError: + pass + try: + return float(v) + except ValueError: + pass + return v.strip("'\"") + + +def parse_simple_yaml(path: str) -> dict[str, Any]: + """Parse a simple YAML file (flat scalars, one-level dicts, lists). + + This is **not** a full YAML parser. It handles the structure of + ``configs/dataset/synthetic.yaml``. + """ + result: dict[str, Any] = {} + current_key: str | None = None + + with open(path) as f: + for raw_line in f: + line = raw_line.rstrip() + + # Skip blank lines and full-line comments + if not line or line.lstrip().startswith("#"): + continue + + indent = len(line) - len(line.lstrip()) + + if indent == 0 and ":" in line: + key, _, value = line.partition(":") + key = key.strip() + value = value.strip() + if value: + result[key] = _parse_scalar(value) + current_key = None + else: + current_key = key + result[key] = {} + continue + + if indent > 0 and line.lstrip().startswith("- "): + item = line.lstrip()[2:].strip() + if current_key is not None: + if isinstance(result.get(current_key), dict) and not result[current_key]: + result[current_key] = [] + if isinstance(result.get(current_key), list): + result[current_key].append(_parse_scalar(item)) + continue + + if indent > 0 and ":" in line and current_key is not None: + k, _, v = line.partition(":") + k = k.strip() + v = v.strip() + if v: + # Strip inline comments + if " #" in v: + v = v[: v.index(" #")].strip() + if not isinstance(result.get(current_key), dict): + result[current_key] = {} + result[current_key][k] = _parse_scalar(v) + continue + + return result + + +# --------------------------------------------------------------------------- +# Dataset generator orchestrator +# --------------------------------------------------------------------------- + + +class DatasetGenerator: + """Orchestrates parallel dataset generation with sharding and checkpointing.""" + + def __init__(self, config: DatasetConfig) -> None: + self.config = config + self.output_path = Path(config.output_dir) + self.shards_dir = self.output_path / "shards" + self.checkpoint_file = self.output_path / ".checkpoint.json" + self.index_file = self.output_path / "index.json" + self.stats_file = self.output_path / "stats.json" + + # -- public API -- + + def run(self) -> None: + """Generate the full dataset.""" + self.output_path.mkdir(parents=True, exist_ok=True) + self.shards_dir.mkdir(parents=True, exist_ok=True) + + shards = self._plan_shards() + total_shards = len(shards) + + # Resume: find already-completed shards + completed: set[int] = set() + if self.config.resume: + completed = self._find_completed_shards() + + pending = [s for s in shards if s.shard_id not in completed] + + if not pending: + logger.info("All %d shards already complete.", total_shards) + else: + logger.info( + "Generating %d shards (%d already complete).", + len(pending), + len(completed), + ) + progress = _make_progress(len(pending)) + workers = max(1, self.config.num_workers) + checkpoint_counter = 0 + + with ProcessPoolExecutor(max_workers=workers) as pool: + futures = { + pool.submit( + _generate_shard_worker, + spec, + str(self.output_path), + ): spec.shard_id + for spec in pending + } + for future in as_completed(futures): + shard_id = futures[future] + try: + result = future.result() + completed.add(result.shard_id) + logger.debug( + "Shard %d: %d examples in %.1fs", + result.shard_id, + result.num_examples, + result.generation_time_s, + ) + except Exception: + logger.error("Shard %d failed", shard_id, exc_info=True) + progress.update(1) + checkpoint_counter += 1 + if checkpoint_counter >= self.config.checkpoint_every: + self._update_checkpoint(completed, total_shards) + checkpoint_counter = 0 + + progress.close() + + # Finalize + self._build_index() + stats = self._compute_statistics() + self._write_statistics(stats) + self._print_summary(stats) + + # Remove checkpoint (generation complete) + if self.checkpoint_file.exists(): + self.checkpoint_file.unlink() + + # -- internal helpers -- + + def _plan_shards(self) -> list[ShardSpec]: + """Divide num_assemblies into shards.""" + n = self.config.num_assemblies + size = self.config.shard_size + num_shards = math.ceil(n / size) + shards: list[ShardSpec] = [] + for i in range(num_shards): + start = i * size + count = min(size, n - start) + shards.append( + ShardSpec( + shard_id=i, + start_example_id=start, + count=count, + seed=_derive_shard_seed(self.config.seed, i), + complexity_distribution=dict(self.config.complexity_distribution), + body_count_min=self.config.body_count_min, + body_count_max=self.config.body_count_max, + grounded_ratio=self.config.grounded_ratio, + ) + ) + return shards + + def _find_completed_shards(self) -> set[int]: + """Scan shards directory for existing shard files.""" + completed: set[int] = set() + if not self.shards_dir.exists(): + return completed + + for p in self.shards_dir.iterdir(): + if p.stem.startswith("shard_"): + try: + shard_id = int(p.stem.split("_")[1]) + # Verify file is non-empty + if p.stat().st_size > 0: + completed.add(shard_id) + except (ValueError, IndexError): + pass + return completed + + def _update_checkpoint(self, completed: set[int], total_shards: int) -> None: + """Write checkpoint file.""" + data = { + "completed_shards": sorted(completed), + "total_shards": total_shards, + "timestamp": time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime()), + } + with open(self.checkpoint_file, "w") as f: + json.dump(data, f) + + def _build_index(self) -> None: + """Build index.json mapping shard files to assembly ID ranges.""" + shards_info: dict[str, dict[str, int]] = {} + total_assemblies = 0 + + for p in sorted(self.shards_dir.iterdir()): + if not p.stem.startswith("shard_"): + continue + try: + shard_id = int(p.stem.split("_")[1]) + except (ValueError, IndexError): + continue + examples = _load_shard(p) + count = len(examples) + start_id = shard_id * self.config.shard_size + shards_info[p.name] = {"start_id": start_id, "count": count} + total_assemblies += count + + fmt = _shard_format().lstrip(".") + index = { + "format_version": 1, + "total_assemblies": total_assemblies, + "total_shards": len(shards_info), + "shard_format": fmt, + "shards": shards_info, + } + with open(self.index_file, "w") as f: + json.dump(index, f, indent=2) + + def _compute_statistics(self) -> dict[str, Any]: + """Aggregate statistics across all shards.""" + classification_counts: dict[str, int] = {} + body_count_hist: dict[int, int] = {} + joint_type_counts: dict[str, int] = {} + dof_values: list[int] = [] + degeneracy_values: list[int] = [] + rigid_count = 0 + minimally_rigid_count = 0 + total = 0 + + for p in sorted(self.shards_dir.iterdir()): + if not p.stem.startswith("shard_"): + continue + examples = _load_shard(p) + for ex in examples: + total += 1 + cls = str(ex.get("assembly_classification", "unknown")) + classification_counts[cls] = classification_counts.get(cls, 0) + 1 + nb = int(ex.get("n_bodies", 0)) + body_count_hist[nb] = body_count_hist.get(nb, 0) + 1 + for j in ex.get("joints", []): + jt = str(j.get("type", "unknown")) + joint_type_counts[jt] = joint_type_counts.get(jt, 0) + 1 + dof_values.append(int(ex.get("internal_dof", 0))) + degeneracy_values.append(int(ex.get("geometric_degeneracies", 0))) + if ex.get("is_rigid"): + rigid_count += 1 + if ex.get("is_minimally_rigid"): + minimally_rigid_count += 1 + + dof_arr = np.array(dof_values) if dof_values else np.zeros(1) + deg_arr = np.array(degeneracy_values) if degeneracy_values else np.zeros(1) + + return { + "total_examples": total, + "classification_distribution": dict(sorted(classification_counts.items())), + "body_count_histogram": dict(sorted(body_count_hist.items())), + "joint_type_distribution": dict(sorted(joint_type_counts.items())), + "dof_statistics": { + "mean": float(dof_arr.mean()), + "std": float(dof_arr.std()), + "min": int(dof_arr.min()), + "max": int(dof_arr.max()), + "median": float(np.median(dof_arr)), + }, + "geometric_degeneracy": { + "assemblies_with_degeneracy": int(np.sum(deg_arr > 0)), + "fraction_with_degeneracy": float(np.mean(deg_arr > 0)), + "mean_degeneracies": float(deg_arr.mean()), + }, + "rigidity": { + "rigid_count": rigid_count, + "rigid_fraction": (rigid_count / total if total > 0 else 0.0), + "minimally_rigid_count": minimally_rigid_count, + "minimally_rigid_fraction": (minimally_rigid_count / total if total > 0 else 0.0), + }, + } + + def _write_statistics(self, stats: dict[str, Any]) -> None: + """Write stats.json.""" + with open(self.stats_file, "w") as f: + json.dump(stats, f, indent=2) + + def _print_summary(self, stats: dict[str, Any]) -> None: + """Print a human-readable summary to stdout.""" + print("\n=== Dataset Generation Summary ===") + print(f"Total examples: {stats['total_examples']}") + print(f"Output directory: {self.output_path}") + print() + print("Classification distribution:") + for cls, count in stats["classification_distribution"].items(): + frac = count / max(stats["total_examples"], 1) * 100 + print(f" {cls}: {count} ({frac:.1f}%)") + print() + print("Joint type distribution:") + for jt, count in stats["joint_type_distribution"].items(): + print(f" {jt}: {count}") + print() + dof = stats["dof_statistics"] + print( + f"DOF: mean={dof['mean']:.1f}, std={dof['std']:.1f}, range=[{dof['min']}, {dof['max']}]" + ) + rig = stats["rigidity"] + print( + f"Rigidity: {rig['rigid_count']}/{stats['total_examples']} " + f"({rig['rigid_fraction'] * 100:.1f}%) rigid, " + f"{rig['minimally_rigid_count']} minimally rigid" + ) + deg = stats["geometric_degeneracy"] + print( + f"Degeneracy: {deg['assemblies_with_degeneracy']} assemblies " + f"({deg['fraction_with_degeneracy'] * 100:.1f}%)" + ) diff --git a/tests/datagen/test_dataset.py b/tests/datagen/test_dataset.py new file mode 100644 index 0000000..398184a --- /dev/null +++ b/tests/datagen/test_dataset.py @@ -0,0 +1,337 @@ +"""Tests for solver.datagen.dataset — dataset generation orchestration.""" + +from __future__ import annotations + +import json +import os +import subprocess +import sys +from typing import TYPE_CHECKING + +from solver.datagen.dataset import ( + DatasetConfig, + DatasetGenerator, + _derive_shard_seed, + _parse_scalar, + parse_simple_yaml, +) + +if TYPE_CHECKING: + from pathlib import Path + from typing import Any + + +# --------------------------------------------------------------------------- +# DatasetConfig +# --------------------------------------------------------------------------- + + +class TestDatasetConfig: + """DatasetConfig construction and defaults.""" + + def test_defaults(self) -> None: + cfg = DatasetConfig() + assert cfg.num_assemblies == 100_000 + assert cfg.seed == 42 + assert cfg.shard_size == 1000 + assert cfg.num_workers == 4 + + def test_from_dict_flat(self) -> None: + d: dict[str, Any] = {"num_assemblies": 500, "seed": 123} + cfg = DatasetConfig.from_dict(d) + assert cfg.num_assemblies == 500 + assert cfg.seed == 123 + + def test_from_dict_nested_body_count(self) -> None: + d: dict[str, Any] = {"body_count": {"min": 3, "max": 20}} + cfg = DatasetConfig.from_dict(d) + assert cfg.body_count_min == 3 + assert cfg.body_count_max == 20 + + def test_from_dict_flat_body_count(self) -> None: + d: dict[str, Any] = {"body_count_min": 5, "body_count_max": 30} + cfg = DatasetConfig.from_dict(d) + assert cfg.body_count_min == 5 + assert cfg.body_count_max == 30 + + def test_from_dict_complexity_distribution(self) -> None: + d: dict[str, Any] = {"complexity_distribution": {"simple": 0.6, "complex": 0.4}} + cfg = DatasetConfig.from_dict(d) + assert cfg.complexity_distribution == {"simple": 0.6, "complex": 0.4} + + def test_from_dict_templates(self) -> None: + d: dict[str, Any] = {"templates": ["chain", "tree"]} + cfg = DatasetConfig.from_dict(d) + assert cfg.templates == ["chain", "tree"] + + +# --------------------------------------------------------------------------- +# Minimal YAML parser +# --------------------------------------------------------------------------- + + +class TestParseScalar: + """_parse_scalar handles different value types.""" + + def test_int(self) -> None: + assert _parse_scalar("42") == 42 + + def test_float(self) -> None: + assert _parse_scalar("3.14") == 3.14 + + def test_bool_true(self) -> None: + assert _parse_scalar("true") is True + + def test_bool_false(self) -> None: + assert _parse_scalar("false") is False + + def test_string(self) -> None: + assert _parse_scalar("hello") == "hello" + + def test_inline_comment(self) -> None: + assert _parse_scalar("0.4 # some comment") == 0.4 + + +class TestParseSimpleYaml: + """parse_simple_yaml handles the synthetic.yaml format.""" + + def test_flat_scalars(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("name: test\nnum: 42\nratio: 0.5\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["name"] == "test" + assert result["num"] == 42 + assert result["ratio"] == 0.5 + + def test_nested_dict(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("body_count:\n min: 2\n max: 50\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["body_count"] == {"min": 2, "max": 50} + + def test_list(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("templates:\n - chain\n - tree\n - loop\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["templates"] == ["chain", "tree", "loop"] + + def test_inline_comments(self, tmp_path: Path) -> None: + yaml_file = tmp_path / "test.yaml" + yaml_file.write_text("dist:\n simple: 0.4 # comment\n") + result = parse_simple_yaml(str(yaml_file)) + assert result["dist"]["simple"] == 0.4 + + def test_synthetic_yaml(self) -> None: + """Parse the actual project config.""" + result = parse_simple_yaml("configs/dataset/synthetic.yaml") + assert result["name"] == "synthetic" + assert result["num_assemblies"] == 100000 + assert isinstance(result["complexity_distribution"], dict) + assert isinstance(result["templates"], list) + assert result["shard_size"] == 1000 + + +# --------------------------------------------------------------------------- +# Shard seed derivation +# --------------------------------------------------------------------------- + + +class TestShardSeedDerivation: + """_derive_shard_seed is deterministic and unique per shard.""" + + def test_deterministic(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(42, 0) + assert s1 == s2 + + def test_different_shards(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(42, 1) + assert s1 != s2 + + def test_different_global_seeds(self) -> None: + s1 = _derive_shard_seed(42, 0) + s2 = _derive_shard_seed(99, 0) + assert s1 != s2 + + +# --------------------------------------------------------------------------- +# DatasetGenerator — small end-to-end tests +# --------------------------------------------------------------------------- + + +class TestDatasetGenerator: + """End-to-end tests with small datasets.""" + + def test_small_generation(self, tmp_path: Path) -> None: + """Generate 10 examples in a single shard.""" + cfg = DatasetConfig( + num_assemblies=10, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + shards_dir = tmp_path / "output" / "shards" + assert shards_dir.exists() + shard_files = sorted(shards_dir.glob("shard_*")) + assert len(shard_files) == 1 + + index_file = tmp_path / "output" / "index.json" + assert index_file.exists() + index = json.loads(index_file.read_text()) + assert index["total_assemblies"] == 10 + + stats_file = tmp_path / "output" / "stats.json" + assert stats_file.exists() + + def test_multi_shard(self, tmp_path: Path) -> None: + """Generate 20 examples across 2 shards.""" + cfg = DatasetConfig( + num_assemblies=20, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + shards_dir = tmp_path / "output" / "shards" + shard_files = sorted(shards_dir.glob("shard_*")) + assert len(shard_files) == 2 + + def test_resume_skips_completed(self, tmp_path: Path) -> None: + """Resume skips already-completed shards.""" + cfg = DatasetConfig( + num_assemblies=20, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + # Record shard modification times + shards_dir = tmp_path / "output" / "shards" + mtimes = {p.name: p.stat().st_mtime for p in shards_dir.glob("shard_*")} + + # Remove stats (simulate incomplete) and re-run + (tmp_path / "output" / "stats.json").unlink() + + DatasetGenerator(cfg).run() + + # Shards should NOT have been regenerated + for p in shards_dir.glob("shard_*"): + assert p.stat().st_mtime == mtimes[p.name] + + # Stats should be regenerated + assert (tmp_path / "output" / "stats.json").exists() + + def test_checkpoint_removed(self, tmp_path: Path) -> None: + """Checkpoint file is cleaned up after completion.""" + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / "output"), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + checkpoint = tmp_path / "output" / ".checkpoint.json" + assert not checkpoint.exists() + + def test_stats_structure(self, tmp_path: Path) -> None: + """stats.json has expected top-level keys.""" + cfg = DatasetConfig( + num_assemblies=10, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + stats = json.loads((tmp_path / "output" / "stats.json").read_text()) + assert stats["total_examples"] == 10 + assert "classification_distribution" in stats + assert "body_count_histogram" in stats + assert "joint_type_distribution" in stats + assert "dof_statistics" in stats + assert "geometric_degeneracy" in stats + assert "rigidity" in stats + + def test_index_structure(self, tmp_path: Path) -> None: + """index.json has expected format.""" + cfg = DatasetConfig( + num_assemblies=15, + output_dir=str(tmp_path / "output"), + shard_size=10, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + index = json.loads((tmp_path / "output" / "index.json").read_text()) + assert index["format_version"] == 1 + assert index["total_assemblies"] == 15 + assert index["total_shards"] == 2 + assert "shards" in index + for _name, info in index["shards"].items(): + assert "start_id" in info + assert "count" in info + + def test_deterministic_output(self, tmp_path: Path) -> None: + """Same seed produces same results.""" + for run_dir in ("run1", "run2"): + cfg = DatasetConfig( + num_assemblies=5, + output_dir=str(tmp_path / run_dir), + shard_size=5, + seed=42, + num_workers=1, + ) + DatasetGenerator(cfg).run() + + s1 = json.loads((tmp_path / "run1" / "stats.json").read_text()) + s2 = json.loads((tmp_path / "run2" / "stats.json").read_text()) + assert s1["total_examples"] == s2["total_examples"] + assert s1["classification_distribution"] == s2["classification_distribution"] + + +# --------------------------------------------------------------------------- +# CLI integration test +# --------------------------------------------------------------------------- + + +class TestCLI: + """Run the script via subprocess.""" + + def test_argparse_mode(self, tmp_path: Path) -> None: + result = subprocess.run( + [ + sys.executable, + "scripts/generate_synthetic.py", + "--num-assemblies", + "5", + "--output-dir", + str(tmp_path / "cli_out"), + "--shard-size", + "5", + "--num-workers", + "1", + "--seed", + "42", + ], + capture_output=True, + text=True, + cwd="/home/developer", + timeout=120, + env={**os.environ, "PYTHONPATH": "/home/developer"}, + ) + assert result.returncode == 0, ( + f"CLI failed:\nstdout: {result.stdout}\nstderr: {result.stderr}" + ) + assert (tmp_path / "cli_out" / "index.json").exists() + assert (tmp_path / "cli_out" / "stats.json").exists()