Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions src/aiida_workgraph/engine/task_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,15 @@ def generate_mapped_tasks(self, zone_task: Task, prefix: str) -> None:
all_links = []
child_tasks = self.get_all_children(zone_task.name)
for child_task in child_tasks:
# The gather_item task is a pure pass-through aggregator
# (executor=return_input); the map zone reads directly from the
# mapped source tasks in `update_map_task_state`, so cloning
# gather_item would just create unused clones. Skipping the
# clone also avoids a race where, for async process-type source
# tasks (CalcJob, WorkChain, @task.graph), the gather_item
# clones stay PLANNED and hang the engine's finalize path.
if self.process.wg.tasks[child_task].identifier == 'workgraph.gather_item':
continue
# since the child task is mapped, it should be skipped
self.state_manager.set_task_runtime_info(child_task, 'state', 'MAPPED')
task = self.copy_task(child_task, prefix)
Expand Down
55 changes: 38 additions & 17 deletions src/aiida_workgraph/engine/task_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,24 +294,45 @@ def update_map_task_state(self, name: str) -> None:
2) gather the results of all the mapped tasks.
3) update the parent task state.
"""
from aiida_workgraph.utils import get_nested_dict

finished, _ = self.are_childen_finished(name)
if finished:
map_zone = self.process.wg.tasks[name]
# gather the results of all the mapped tasks
gather_task = map_zone.gather_item_task
for input in gather_task.inputs:
if input._name.startswith('_'):
continue
results = {}
link = input._links[0]
for prefix, mapped_task in self.process.wg.tasks[gather_task.name].mapped_tasks.items():
results[prefix] = self.ctx._task_results[mapped_task.name][link.to_socket._name]
self.ctx._task_results[name][link.to_socket._name] = results
self.set_task_runtime_info(name, 'state', 'FINISHED')
# self.update_meta_tasks(name)
self.process.report(f'Task: {name} finished.')
self.update_meta_tasks(name)
self.update_parent_task_state(name)
if not finished:
return
map_zone = self.process.wg.tasks[name]
# Gather the results of all the mapped tasks.
#
# We aggregate directly from each mapped SOURCE task (the task whose
# output is linked into the template gather_item), not via the
# gather_item itself. The gather_item template is a pure pass-through
# aggregator (executor=return_input) and is intentionally not cloned
# per item in `generate_mapped_tasks`, so there are no gather_item
# clones to read from. Reading directly from the source's
# `_task_results` is also race-free: the source's results are
# populated by `update_task_state` before any cascade can reach here,
# which matters when the source is an async process-type task
# (CalcJob, WorkChain, or a @task.graph sub-workflow).
gather_task = map_zone.gather_item_task
for input in gather_task.inputs:
if input._name.startswith('_'):
continue
if not input._links:
continue
link = input._links[0]
source_task = self.process.wg.tasks[link.from_task.name]
source_clones = source_task.mapped_tasks or {}
results = {}
for prefix, clone in source_clones.items():
results[prefix] = get_nested_dict(
self.ctx._task_results[clone.name],
link.from_socket._scoped_name,
default=None,
)
self.ctx._task_results[name][link.to_socket._name] = results
self.set_task_runtime_info(name, 'state', 'FINISHED')
self.process.report(f'Task: {name} finished.')
self.update_meta_tasks(name)
self.update_parent_task_state(name)

def update_template_task_state(self, name: str) -> None:
"""Update the template task state.
Expand Down
Loading