diff --git a/src/Mod/CAM/App/tsp_solver_pybind.cpp b/src/Mod/CAM/App/tsp_solver_pybind.cpp index 021bc1ad3c..24a2ddac95 100644 --- a/src/Mod/CAM/App/tsp_solver_pybind.cpp +++ b/src/Mod/CAM/App/tsp_solver_pybind.cpp @@ -108,14 +108,17 @@ std::vector tspSolveTunnelsPy( std::vector cppTunnels; // Convert Python dictionaries to C++ TSPTunnel objects - for (const auto& tunnel : tunnels) { + for (size_t i = 0; i < tunnels.size(); ++i) { + const auto& tunnel = tunnels[i]; double startX = py::cast(tunnel["startX"]); double startY = py::cast(tunnel["startY"]); double endX = py::cast(tunnel["endX"]); double endY = py::cast(tunnel["endY"]); bool isOpen = tunnel.contains("isOpen") ? py::cast(tunnel["isOpen"]) : true; - cppTunnels.emplace_back(startX, startY, endX, endY, isOpen); + TSPTunnel cppTunnel(startX, startY, endX, endY, isOpen); + cppTunnel.index = static_cast(i); + cppTunnels.emplace_back(cppTunnel); } // Handle optional start point @@ -173,10 +176,12 @@ std::vector tspSolveTunnelsPy( // Solve the tunnel TSP auto result = TSPSolver::solveTunnels(cppTunnels, allowFlipping, pStartPoint, pEndPoint); - // Convert result back to Python dictionaries + // Convert result back to Python dictionaries, preserving extra keys from input std::vector pyResult; for (const auto& tunnel : result) { - py::dict tunnelDict; + // Start with a copy of the original input dict to preserve extra keys + py::dict tunnelDict = py::dict(tunnels[tunnel.index]); + // Update with solver results (may have changed due to flipping) tunnelDict["startX"] = tunnel.startX; tunnelDict["startY"] = tunnel.startY; tunnelDict["endX"] = tunnel.endX; diff --git a/src/Mod/CAM/CAMTests/TestTSPSolver.py b/src/Mod/CAM/CAMTests/TestTSPSolver.py index 854b790365..ca908fd363 100644 --- a/src/Mod/CAM/CAMTests/TestTSPSolver.py +++ b/src/Mod/CAM/CAMTests/TestTSPSolver.py @@ -62,6 +62,13 @@ class TestTSPSolver(PathTestBase): f" {i} (orig {orig_idx}): ({tunnel['startX']:.2f},{tunnel['startY']:.2f}) -> ({tunnel['endX']:.2f},{tunnel['endY']:.2f}){flipped_str}" ) + # Print extra data if present + standard_keys = {"startX", "startY", "endX", "endY", "isOpen", "flipped", "index"} + extra_keys = [k for k in tunnel.keys() if k not in standard_keys] + if extra_keys: + extra_data = {k: tunnel[k] for k in extra_keys} + print(f" Extra data: {extra_data}") + def test_01_simple_tsp(self): """Test TSP solver with a simple square of points.""" # Test the TSP solver on a simple square @@ -354,6 +361,83 @@ class TestTSPSolver(PathTestBase): # The route should end at the specified end point # Note: Due to current implementation limitations, this may not be enforced + def test_09_tunnels_extra_data_passthrough(self): + """Test that extra data in tunnel dictionaries is preserved through TSP solving.""" + tunnels = [ + { + "startX": 0, + "startY": 0, + "endX": 5, + "endY": 0, + "tool": "drill_1mm", + "speed": 1000, + "feed": 500, + "custom_id": "tunnel_0", + }, + { + "startX": 20, + "startY": 5, + "endX": 25, + "endY": 5, + "tool": "drill_3mm", + "speed": 600, + "feed": 200, + "notes": "high precision", + "custom_id": "tunnel_2", + }, + { + "startX": 5, + "startY": 17, + "endX": 15, + "endY": 0, + "tool": "mill_2mm", + "speed": 800, + "feed": 300, + "material": "aluminum", + "custom_id": "tunnel_1", + }, + ] + + self.print_tunnels(tunnels, "Input tunnels with extra data") + + # Test with flipping allowed to ensure extra data survives optimization + result = PathUtils.sort_tunnels_tsp(tunnels, allowFlipping=True) + + self.print_tunnels(result, "Sorted tunnels with extra data preserved") + + # Verify all tunnels are present + self.assertEqual(len(result), 3) + + # Verify extra data is preserved for each tunnel + for tunnel in result: + # Check that solver-added keys are present + self.assertIn("startX", tunnel) + self.assertIn("startY", tunnel) + self.assertIn("endX", tunnel) + self.assertIn("endY", tunnel) + self.assertIn("isOpen", tunnel) + self.assertIn("flipped", tunnel) + self.assertIn("index", tunnel) + + # Check that extra keys are preserved + self.assertIn("tool", tunnel) + self.assertIn("speed", tunnel) + self.assertIn("feed", tunnel) + self.assertIn("custom_id", tunnel) + + # Verify specific values based on original index + original_tunnel = tunnels[tunnel["index"]] + self.assertEqual(tunnel["tool"], original_tunnel["tool"]) + self.assertEqual(tunnel["speed"], original_tunnel["speed"]) + self.assertEqual(tunnel["feed"], original_tunnel["feed"]) + self.assertEqual(tunnel["custom_id"], original_tunnel["custom_id"]) + + # Check tunnel-specific extra data + if tunnel["index"] == 2: + self.assertEqual(tunnel["material"], "aluminum") + elif tunnel["index"] == 1: + self.assertEqual(tunnel["notes"], "high precision") + if __name__ == "__main__": import unittest