Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ def generate_mapped_tasks(self, zone_task: Task, prefix: str) -> None:
all_links = []
child_tasks = self.get_all_children(zone_task.name)
for child_task in child_tasks:
# The gather_item task is a pure pass-through aggregator
# (executor=return_input); the map zone reads directly from the
# mapped source tasks in `update_map_task_state`, so cloning
# gather_item would just create unused clones. Skipping the
# clone also avoids a race where, for async process-type source
# tasks (CalcJob, WorkChain, @task.graph), the gather_item
# clones stay PLANNED and hang the engine's finalize path.
if self.process.wg.tasks[child_task].identifier == 'workgraph.gather_item':
continue
# since the child task is mapped, it should be skipped
self.state_manager.set_task_runtime_info(child_task, 'state', 'MAPPED')
task = self.copy_task(child_task, prefix)
Expand Down
71 changes: 54 additions & 17 deletions src/aiida_workgraph/engine/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,24 +294,61 @@ def update_map_task_state(self, name: str) -> None:
2) gather the results of all the mapped tasks.
3) update the parent task state.
"""
from aiida_workgraph.utils import get_nested_dict

finished, _ = self.are_childen_finished(name)
if finished:
map_zone = self.process.wg.tasks[name]
# gather the results of all the mapped tasks
gather_task = map_zone.gather_item_task
for input in gather_task.inputs:
if input._name.startswith('_'):
continue
results = {}
link = input._links[0]
for prefix, mapped_task in self.process.wg.tasks[gather_task.name].mapped_tasks.items():
results[prefix] = self.ctx._task_results[mapped_task.name][link.to_socket._name]
self.ctx._task_results[name][link.to_socket._name] = results
self.set_task_runtime_info(name, 'state', 'FINISHED')
# self.update_meta_tasks(name)
self.process.report(f'Task: {name} finished.')
self.update_meta_tasks(name)
self.update_parent_task_state(name)
if not finished:
return
map_zone = self.process.wg.tasks[name]
# Gather the results of all the mapped tasks.
#
# We aggregate directly from each mapped SOURCE task (the task whose
# output is linked into the template gather_item), not via the
# gather_item itself. The gather_item template is a pure pass-through
# aggregator (executor=return_input) and is intentionally not cloned
# per item in `generate_mapped_tasks`, so there are no gather_item
# clones to read from. Reading directly from the source's
# `_task_results` is also race-free: the source's results are
# populated by `update_task_state` before any cascade can reach here,
# which matters when the source is an async process-type task
# (CalcJob, WorkChain, or a @task.graph sub-workflow).
gather_task = map_zone.gather_item_task
for input in gather_task.inputs:
if input._name.startswith('_'):
continue
if not input._links:
continue
link = input._links[0]
source_task = self.process.wg.tasks[link.from_task.name]
source_clones = source_task.mapped_tasks or {}
results = {}
for prefix, clone in source_clones.items():
results[prefix] = get_nested_dict(
self.ctx._task_results[clone.name],
link.from_socket._scoped_name,
default=None,
)
self.ctx._task_results[name][link.to_socket._name] = results
# Persist the gathered-result PKs on the process node so the client
# can reconstruct zone outputs after wg.update(). Zone tasks have no
# process node of their own, so without this the client cannot recover
# the per-prefix result nodes.
result_pks: dict = {}
for socket_name, val in self.ctx._task_results[name].items():
if socket_name.startswith('_'):
continue
if isinstance(val, dict):
result_pks[socket_name] = {
prefix: node.pk for prefix, node in val.items() if hasattr(node, 'pk') and node.pk is not None
}
if result_pks:
map_info = self.process.node.get_task_map_info(name) or {}
map_info['result_pks'] = result_pks
self.set_task_runtime_info(name, 'map_info', map_info)
self.set_task_runtime_info(name, 'state', 'FINISHED')
self.process.report(f'Task: {name} finished.')
self.update_meta_tasks(name)
self.update_parent_task_state(name)

def update_template_task_state(self, name: str) -> None:
"""Update the template task state.
Expand Down
8 changes: 7 additions & 1 deletion src/aiida_workgraph/tasks/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,13 @@ def gather(self, sockets: Dict[str, BaseSocket]) -> None:
gather_item = self.gather_item_task
for name in sockets:
gather_item.add_input_spec('workgraph.any', name=name)
self.add_output_spec('workgraph.namespace', name=name)
# The gathered output namespace must be dynamic so the client
# can populate it with per-prefix keys after the run completes.
self.add_output_spec(
'workgraph.namespace',
name=name,
meta=SocketMeta(dynamic=True),
)
gather_item.set_inputs(sockets)
return gather_item.outputs

Expand Down
32 changes: 32 additions & 0 deletions src/aiida_workgraph/workgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,13 +322,45 @@ def update(self) -> None:
continue
self.tasks[name].update_state(data)

# Zone/Map tasks have no process node (pk=None), so update_state
# cannot populate their outputs. The engine persists the gathered
# result PKs in task_map_info; reconstruct outputs from those.
self._populate_zone_outputs(processes_data)

if self.widget is not None:
states = {name: data['state'] for name, data in processes_data.items()}
self.widget.states = states

if self.process.is_finished_ok:
self.outputs._set_socket_value(resolve_node_link_managers(self.process.outputs))

def _populate_zone_outputs(self, processes_data: Dict[str, Any]) -> None:
"""Populate outputs for zone/Map tasks from persisted result PKs.

Zone tasks have no AiiDA process node, so ``Task.update_state`` cannot
load their outputs. The engine persists the gathered result node PKs in
``task_map_info[name]['result_pks']``; this method loads those nodes and
sets them on the corresponding output sockets.
"""
import aiida.orm

for name, data in processes_data.items():
if name not in self.tasks:
continue
# Only handle tasks with no process node that are finished.
if data['pk'] is not None or data['state'] != 'FINISHED':
continue
map_info = self.process.get_task_map_info(name)
if not map_info:
continue
result_pks = map_info.get('result_pks', {})
task = self.tasks[name]
for socket_name, pk_map in result_pks.items():
if socket_name not in task.outputs._sockets:
continue
values = {prefix: aiida.orm.load_node(pk) for prefix, pk in pk_map.items()}
task.outputs[socket_name]._value = values

@property
def pk(self) -> Optional[int]:
return self.process.pk if self.process else None
Expand Down
Loading