diff --git a/src/aiida_workgraph/engine/task_manager.py b/src/aiida_workgraph/engine/task_manager.py index c6491721..b2ca63aa 100644 --- a/src/aiida_workgraph/engine/task_manager.py +++ b/src/aiida_workgraph/engine/task_manager.py @@ -469,6 +469,15 @@ def generate_mapped_tasks(self, zone_task: Task, prefix: str) -> None: all_links = [] child_tasks = self.get_all_children(zone_task.name) for child_task in child_tasks: + # The gather_item task is a pure pass-through aggregator + # (executor=return_input); the map zone reads directly from the + # mapped source tasks in `update_map_task_state`, so cloning + # gather_item would just create unused clones. Skipping the + # clone also avoids a race where, for async process-type source + # tasks (CalcJob, WorkChain, @task.graph), the gather_item + # clones stay PLANNED and hang the engine's finalize path. + if self.process.wg.tasks[child_task].identifier == 'workgraph.gather_item': + continue # since the child task is mapped, it should be skipped self.state_manager.set_task_runtime_info(child_task, 'state', 'MAPPED') task = self.copy_task(child_task, prefix) diff --git a/src/aiida_workgraph/engine/task_state.py b/src/aiida_workgraph/engine/task_state.py index bbcef9a8..b2f2b2b3 100644 --- a/src/aiida_workgraph/engine/task_state.py +++ b/src/aiida_workgraph/engine/task_state.py @@ -294,24 +294,61 @@ def update_map_task_state(self, name: str) -> None: 2) gather the results of all the mapped tasks. 3) update the parent task state. """ + from aiida_workgraph.utils import get_nested_dict + finished, _ = self.are_childen_finished(name) - if finished: - map_zone = self.process.wg.tasks[name] - # gather the results of all the mapped tasks - gather_task = map_zone.gather_item_task - for input in gather_task.inputs: - if input._name.startswith('_'): - continue - results = {} - link = input._links[0] - for prefix, mapped_task in self.process.wg.tasks[gather_task.name].mapped_tasks.items(): - results[prefix] = self.ctx._task_results[mapped_task.name][link.to_socket._name] - self.ctx._task_results[name][link.to_socket._name] = results - self.set_task_runtime_info(name, 'state', 'FINISHED') - # self.update_meta_tasks(name) - self.process.report(f'Task: {name} finished.') - self.update_meta_tasks(name) - self.update_parent_task_state(name) + if not finished: + return + map_zone = self.process.wg.tasks[name] + # Gather the results of all the mapped tasks. + # + # We aggregate directly from each mapped SOURCE task (the task whose + # output is linked into the template gather_item), not via the + # gather_item itself. The gather_item template is a pure pass-through + # aggregator (executor=return_input) and is intentionally not cloned + # per item in `generate_mapped_tasks`, so there are no gather_item + # clones to read from. Reading directly from the source's + # `_task_results` is also race-free: the source's results are + # populated by `update_task_state` before any cascade can reach here, + # which matters when the source is an async process-type task + # (CalcJob, WorkChain, or a @task.graph sub-workflow). + gather_task = map_zone.gather_item_task + for input in gather_task.inputs: + if input._name.startswith('_'): + continue + if not input._links: + continue + link = input._links[0] + source_task = self.process.wg.tasks[link.from_task.name] + source_clones = source_task.mapped_tasks or {} + results = {} + for prefix, clone in source_clones.items(): + results[prefix] = get_nested_dict( + self.ctx._task_results[clone.name], + link.from_socket._scoped_name, + default=None, + ) + self.ctx._task_results[name][link.to_socket._name] = results + # Persist the gathered-result PKs on the process node so the client + # can reconstruct zone outputs after wg.update(). Zone tasks have no + # process node of their own, so without this the client cannot recover + # the per-prefix result nodes. + result_pks: dict = {} + for socket_name, val in self.ctx._task_results[name].items(): + if socket_name.startswith('_'): + continue + if isinstance(val, dict): + result_pks[socket_name] = { + prefix: node.pk for prefix, node in val.items() if hasattr(node, 'pk') and node.pk is not None + } + if result_pks: + map_info = self.process.node.get_task_map_info(name) or {} + map_info['result_pks'] = result_pks + self.set_task_runtime_info(name, 'map_info', map_info) + self.set_task_runtime_info(name, 'state', 'FINISHED') + self.process.report(f'Task: {name} finished.') + self.update_meta_tasks(name) + self.update_parent_task_state(name) def update_template_task_state(self, name: str) -> None: """Update the template task state. diff --git a/src/aiida_workgraph/tasks/builtins.py b/src/aiida_workgraph/tasks/builtins.py index a7742c30..51d57935 100644 --- a/src/aiida_workgraph/tasks/builtins.py +++ b/src/aiida_workgraph/tasks/builtins.py @@ -128,7 +128,13 @@ def gather(self, sockets: Dict[str, BaseSocket]) -> None: gather_item = self.gather_item_task for name in sockets: gather_item.add_input_spec('workgraph.any', name=name) - self.add_output_spec('workgraph.namespace', name=name) + # The gathered output namespace must be dynamic so the client + # can populate it with per-prefix keys after the run completes. + self.add_output_spec( + 'workgraph.namespace', + name=name, + meta=SocketMeta(dynamic=True), + ) gather_item.set_inputs(sockets) return gather_item.outputs diff --git a/src/aiida_workgraph/workgraph.py b/src/aiida_workgraph/workgraph.py index 60858200..f0009bd1 100644 --- a/src/aiida_workgraph/workgraph.py +++ b/src/aiida_workgraph/workgraph.py @@ -322,6 +322,11 @@ def update(self) -> None: continue self.tasks[name].update_state(data) + # Zone/Map tasks have no process node (pk=None), so update_state + # cannot populate their outputs. The engine persists the gathered + # result PKs in task_map_info; reconstruct outputs from those. + self._populate_zone_outputs(processes_data) + if self.widget is not None: states = {name: data['state'] for name, data in processes_data.items()} self.widget.states = states @@ -329,6 +334,33 @@ def update(self) -> None: if self.process.is_finished_ok: self.outputs._set_socket_value(resolve_node_link_managers(self.process.outputs)) + def _populate_zone_outputs(self, processes_data: Dict[str, Any]) -> None: + """Populate outputs for zone/Map tasks from persisted result PKs. + + Zone tasks have no AiiDA process node, so ``Task.update_state`` cannot + load their outputs. The engine persists the gathered result node PKs in + ``task_map_info[name]['result_pks']``; this method loads those nodes and + sets them on the corresponding output sockets. + """ + import aiida.orm + + for name, data in processes_data.items(): + if name not in self.tasks: + continue + # Only handle tasks with no process node that are finished. + if data['pk'] is not None or data['state'] != 'FINISHED': + continue + map_info = self.process.get_task_map_info(name) + if not map_info: + continue + result_pks = map_info.get('result_pks', {}) + task = self.tasks[name] + for socket_name, pk_map in result_pks.items(): + if socket_name not in task.outputs._sockets: + continue + values = {prefix: aiida.orm.load_node(pk) for prefix, pk in pk_map.items()} + task.outputs[socket_name]._value = values + @property def pk(self) -> Optional[int]: return self.process.pk if self.process else None