diff --git a/docs/source/howto/imperative.ipynb b/docs/source/howto/imperative.ipynb new file mode 100644 index 00000000..3d341972 --- /dev/null +++ b/docs/source/howto/imperative.ipynb @@ -0,0 +1,470 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "bd1eee19", + "metadata": {}, + "source": [ + "# Imperative Workflow\n", + "\n", + "\n", + "## Why an “imperative” API?\n", + "\n", + "The default WorkGraph engines ask you to **declare** every task and dependency *before* execution.\n", + "The *imperative* engine in **aiida‑workgraph** lets you write an ordinary Python `async` function instead. Each time you call a task inside that function the engine immediately schedules it and updates the WorkGraph on‑the‑fly.\n", + "\n", + "*Benefits*\n", + "\n", + "* **Natural control‑flow** – use native `while`, `for`, `if/else`, exceptions.\n", + "* **Incremental graphs** – dependencies are inferred automatically.\n", + "* **Rapid prototyping** – no context variables or DSL constructs required.\n", + "\n", + "If you need full static inspection before running, the declarative API is still available; you can even mix both styles in one project.\n", + "\n", + "## Two mental models for building workflows\n", + "\n", + "\n", + "| Aspect | *WorkGraph “design-time” zone* | *Imperative async function* |\n", + "| :------------------------------------------- | :--------------------------------------------------------------------- | :-------------------------------------------------------------------- |\n", + "| How the graph is produced | User **declares** every task and edge up-front; the full DAG exists before execution. | Graph **emerges at run time** as normal Python control-flow is evaluated. |\n", + "| Can users inspect / validate before running? | Yes – they can traverse, statically analyse, export HTML or visualise the whole DAG. | Not until the first run (the engine can surface partial DAGs on-the-fly). 
|\n", + "| Expressiveness for loops / branching | Needs explicit constructs (`active_while_zone`, context vars). | Natural Python `while`, `for`, `if`; no extra concepts. |\n", + "| Scheduling & re-runs | Easy for a scheduler to plan because every task is known. | Scheduler must cope with tasks appearing late; re-runs need replay or state capture. |\n", + "| Cognitive load | Higher upfront (DSL, manual dependency wiring). | Lower for Python users; they “just write a function”. |\n", + "| Implementation complexity | Mostly synchronous; no -- or minimal -- `async` knowledge required. | Requires `async`/`await` or background threads so that the runner stays non-blocking. |\n", + "| Performance knobs | Static optimisations (parallelism, caching) are computable before run. | Need adaptive scheduling; late-bound tasks may surprise resource planners. |\n", + "\n", + "## The simplest flow\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "c4f11731", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:18:51 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:18:51 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:18:51 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: add\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': , 'y': }\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:18:52 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + 
"07/03/2025 04:18:52 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|update_normal_task_state]: Task: _flow finished.\n", + "07/03/2025 04:18:52 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:18:52 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: 138144\n", + "07/03/2025 04:18:55 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|update_task_state]: Task: add, type: PythonJob, finished.\n", + "07/03/2025 04:18:55 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: multiply\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': 1, 'y': }\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:18:56 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: 138154\n", + "07/03/2025 04:18:58 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|update_task_state]: Task: multiply, type: PythonJob, finished.\n", + "07/03/2025 04:18:59 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:18:59 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138140|WorkGraphImperativeEngine|finalize]: Finalize workgraph.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "graph results: {'sum': 
{'socket_name': 'result', 'task_name': 'add'}, 'product': {'socket_name': 'result', 'task_name': 'multiply'}}\n", + "========================================\n", + "\n", + "Results:\n", + "{'sum': , 'product': }\n" + ] + } + ], + "source": [ + "from aiida_workgraph import task\n", + "from aiida_workgraph.engine.imperative.imperative import WorkGraphImperativeEngine\n", + "from aiida.engine import run, submit\n", + "from aiida import orm, load_profile\n", + "\n", + "load_profile()\n", + "\n", + "@task.pythonjob()\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "@task.pythonjob()\n", + "def multiply(x, y):\n", + " return x * y\n", + "\n", + "async def add_multiply(x, y):\n", + " a = add(x, y)\n", + " b = multiply(1, a.result) # chain directly on the result\n", + " return {\"sum\": a.result, \"product\": b.result}\n", + "\n", + "results = run(\n", + " WorkGraphImperativeEngine,\n", + " inputs={\"workgraph_data\": {\n", + " \"name\": \"add_multiply\",\n", + " \"flow\": add_multiply,\n", + " \"function_inputs\": {\"x\": orm.Int(3), \"y\": orm.Int(4)}\n", + " }},\n", + ")\n", + "print(\"=\" * 40)\n", + "print(\"\\nResults:\")\n", + "print(results)\n" + ] + }, + { + "cell_type": "markdown", + "id": "775f8c2a", + "metadata": {}, + "source": [ + "\n", + "### What happens under the hood?\n", + "\n", + "1. The engine creates a fresh WorkGraph called **add\\_multiply**.\n", + "2. `add()` is scheduled; its output node is stored as `a`.\n", + "3. `multiply()` is scheduled as soon as its dependency (`a.result`) is ready.\n", + "4. 
When the async function returns, the engine marks the WorkGraph **FINISHED** and returns `results`.\n", + "\n", + "> **Tip** You can use the GUI to inspect the evolving DAG.\n", + "\n", + "---\n", + "\n", + "## Conditional branches (`if/else`)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5b6a8af0", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:00 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:00 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:19:00 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: add\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': , 'y': }\n", + "node state: RUNNING\n", + "node state: RUNNING\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:07 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|update_task_state]: Task: add, type: PythonJob, finished.\n", + "07/03/2025 04:19:07 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:07 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:19:12 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: multiply\n" + ] + }, + { +
"name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': -1, 'y': }\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:13 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|update_normal_task_state]: Task: _flow finished.\n", + "07/03/2025 04:19:13 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:13 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: 138175\n", + "07/03/2025 04:19:19 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|update_task_state]: Task: multiply, type: PythonJob, finished.\n", + "07/03/2025 04:19:19 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:20 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138161|WorkGraphImperativeEngine|finalize]: Finalize workgraph.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "graph results: {'sum': {'socket_name': 'result', 'task_name': 'add'}, 'multiply': {'socket_name': 'result', 'task_name': 'multiply'}}\n", + "========================================\n", + "\n", + "Results:\n", + "{'sum': , 'multiply': }\n" + ] + } + ], + "source": [ + "from aiida_workgraph.engine.imperative.imperative import wait_for\n", + "\n", + "@task.pythonjob()\n", + "def add(x, y):\n", + " import time; time.sleep(2) # simulate work\n", + " return x + y\n", + "\n", + "@task.pythonjob()\n", + "def multiply(x, y):\n", + " import time; time.sleep(2)\n", + " return x * y\n", + "\n", + "async def add_then_branch(x, y):\n", + " a 
= add(x, y)\n", + " await wait_for(a) # don\\'t read the result too early\n", + "\n", + " if a.result.value > 10:\n", + " m = multiply(1, a.result)\n", + " else:\n", + " m = multiply(-1, a.result)\n", + "\n", + " return {\"sum\": a.result, \"multiply\": m.result}\n", + "\n", + "results = run(\n", + " WorkGraphImperativeEngine,\n", + " inputs={\"workgraph_data\": {\n", + " \"name\": \"add_multiply\",\n", + " \"flow\": add_then_branch,\n", + " \"function_inputs\": {\"x\": orm.Int(3), \"y\": orm.Int(4)}\n", + " }},\n", + ")\n", + "print(\"=\" * 40)\n", + "print(\"\\nResults:\")\n", + "print(results)\n" + ] + }, + { + "cell_type": "markdown", + "id": "cda9a8ba", + "metadata": {}, + "source": [ + "\n", + "The `if` statement is ordinary Python; the engine only schedules the branch that is actually taken. The skipped branch never appears in the WorkGraph.\n", + "\n", + "\n", + "## Loops (`while`)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "7d02b8d4", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:21 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:21 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:19:21 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: add\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': , 'y': }\n", + "node state: RUNNING\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:24 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] 
[138182|WorkGraphImperativeEngine|update_task_state]: Task: add, type: PythonJob, finished.\n", + "07/03/2025 04:19:24 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:24 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:19:27 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: add1\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': , 'y': 1}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:28 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "node state: PLANNED\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:31 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|update_task_state]: Task: add1, type: PythonJob, finished.\n", + "07/03/2025 04:19:31 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: multiply\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "kwargs: {'x': , 'y': 2}\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:32 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow, 138207\n" + ] + }, + { + "name": "stdout", + 
"output_type": "stream", + "text": [ + "node state: RUNNING\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "07/03/2025 04:19:35 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|update_task_state]: Task: multiply, type: PythonJob, finished.\n", + "07/03/2025 04:19:35 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:35 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|on_wait]: Process status: Waiting for child processes: _flow\n", + "07/03/2025 04:19:39 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|update_normal_task_state]: Task: _flow finished.\n", + "07/03/2025 04:19:40 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|continue_workgraph]: tasks ready to run: \n", + "07/03/2025 04:19:40 PM <193192> aiida.orm.nodes.process.workflow.workchain.WorkChainNode: [REPORT] [138182|WorkGraphImperativeEngine|finalize]: Finalize workgraph.\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "graph results: {'sum': {'socket_name': 'result', 'task_name': 'multiply'}}\n", + "========================================\n", + "\n", + "Results:\n", + "{'sum': }\n" + ] + } + ], + "source": [ + "from aiida_workgraph import task\n", + "from aiida_workgraph.engine.imperative.imperative import WorkGraphImperativeEngine, wait_for\n", + "from aiida.engine import run, submit\n", + "from aiida import orm, load_profile\n", + "\n", + "load_profile()\n", + "\n", + "@task.pythonjob()\n", + "def add(x, y):\n", + " return x + y\n", + "\n", + "@task.pythonjob()\n", + "def multiply(x, y):\n", + " return x * y\n", + "\n", + "\n", + "async def keep_doubling(x, y):\n", 
+ " outputs = add(x, y)\n", + " await wait_for(outputs)\n", + "\n", + " while outputs.result.value < 10:\n", + " outputs1 = add(outputs.result, 1)\n", + " outputs = multiply(outputs1.result, 2)\n", + " await wait_for(outputs) # make sure we can read new value\n", + "\n", + " return {\"sum\": outputs.result}\n", + "\n", + "results = run(\n", + " WorkGraphImperativeEngine,\n", + " inputs={\"workgraph_data\": {\n", + " \"name\": \"keep_doubling\",\n", + " \"flow\": keep_doubling,\n", + " \"function_inputs\": {\"x\": orm.Int(3), \"y\": orm.Int(4)}\n", + " }},\n", + ")\n", + "print(\"=\" * 40)\n", + "print(\"\\nResults:\")\n", + "print(results)" + ] + }, + { + "cell_type": "markdown", + "id": "bfd37a58", + "metadata": {}, + "source": [ + "\n", + "Every iteration adds two more tasks to the WorkGraph. Because the graph grows dynamically you can observe it expanding live.\n", + "\n", + "> **Performance note** Do not schedule thousands of tiny jobs.\n", + "\n", + "\n", + "## Waiting for results\n", + "\n", + "`wait_for(task_socket)` suspends the *flow* until the referenced task has reached a terminal state. This is essential in loops or branches where you need the **value**, not just the future placeholder. 
Under the hood it performs an asynchronous poll so your runner remains free to execute other coroutines.\n", + "\n", + "---\n", + "\n", + "\n", + "## Conclusion\n", + "\n", + "The imperative API combines the **clarity of Python** with the **provenance guarantees of AiiDA**:\n", + "\n", + "* Write workflows as natural coroutines.\n", + "* Leverage familiar control structures instead of special DSL constructs.\n", + "* Still obtain a fully queryable, shareable WorkGraph.\n", + "\n", + "When you need static analysis or pre‑execution validation, drop back to the declarative zones—both styles interoperate because they share the same engine under the hood.\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "aiida", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/source/howto/index.rst b/docs/source/howto/index.rst index 4fae80c9..1cd49890 100644 --- a/docs/source/howto/index.rst +++ b/docs/source/howto/index.rst @@ -29,3 +29,4 @@ This section contains a collection of HowTos for various topics.
control transfer_workchain workchain_call_workgraph + imperative diff --git a/src/aiida_workgraph/decorator.py b/src/aiida_workgraph/decorator.py index 7708eaeb..6a53e8a7 100644 --- a/src/aiida_workgraph/decorator.py +++ b/src/aiida_workgraph/decorator.py @@ -80,19 +80,31 @@ def _make_wrapper(TaskCls, func): @functools.wraps(func) def wrapper(*call_args, **call_kwargs): from aiida_workgraph.manager import get_current_graph - - graph = get_current_graph() - if graph is None: - raise RuntimeError(f"No active Graph available for {func.__name__}.") + from aiida.engine.utils import is_process_scoped + from aiida.engine import Process + from aiida_workgraph.engine.imperative.imperative import ( + WorkGraphImperativeEngine, + ) + + is_imperative = False + if is_process_scoped(): + process = Process.current() + if isinstance(process, WorkGraphImperativeEngine): + is_imperative = True + if is_imperative: + process = Process.current() + graph = process.wg + else: + graph = get_current_graph() + if graph is None: + raise RuntimeError(f"No active Graph available for {func.__name__}.") task = graph.add_task(TaskCls) active_zone = getattr(graph, "_active_zone", None) if active_zone: active_zone.children.add(task) - inputs = dict(call_kwargs or {}) arguments = list(call_args) orginal_func = func._func if hasattr(func, "_func") else func - for name, parameter in inspect.signature(orginal_func).parameters.items(): if parameter.kind in [ parameter.POSITIONAL_ONLY, @@ -105,13 +117,26 @@ def wrapper(*call_args, **call_kwargs): elif parameter.kind is parameter.VAR_POSITIONAL: # not supported raise ValueError("VAR_POSITIONAL is not supported.") - task.set(inputs) outputs = [ output for output in task.outputs if output._name not in ["_wait", "_outputs", "exit_code"] ] + if is_imperative: + from aiida_workgraph.utils.analysis import WorkGraphSaver + + workgraph_data = process.wg.prepare_inputs()["workgraph_data"] + saver = WorkGraphSaver( + process.node, + workgraph_data, + 
restart_process=process._raw_inputs["workgraph_data"].get( + "restart_process", None + ), + ) + graph.connectivity = graph.build_connectivity() + saver.save(update_state=False) + process._do_step() if len(outputs) == 1: return outputs[0] else: diff --git a/src/aiida_workgraph/engine/awaitable_manager.py b/src/aiida_workgraph/engine/awaitable_manager.py index 0f25375e..5fa03c95 100644 --- a/src/aiida_workgraph/engine/awaitable_manager.py +++ b/src/aiida_workgraph/engine/awaitable_manager.py @@ -106,6 +106,21 @@ def action_awaitables(self) -> None: else: assert f"invalid awaitable target '{awaitable.target}'" + def clean_socket_results(self, results) -> object: + """Recursively replace sockets in the results by plain reference dicts. + + TODO: this is hardcoded for TaskSocket, we should make it more generic + """ + from aiida_workgraph.socket import TaskSocket, TaskSocketNamespace + + if isinstance(results, dict): + for key, result in results.items(): + results[key] = self.clean_socket_results(result) + elif isinstance(results, (TaskSocket, TaskSocketNamespace)): + # swap the socket for a reference to its task and socket name + return {"socket_name": results._name, "task_name": results._node.name} + return results + def on_awaitable_finished(self, awaitable: Awaitable) -> None: """Callback function, for when an awaitable process instance is completed.
@@ -157,6 +172,7 @@ def on_awaitable_finished(self, awaitable: Awaitable) -> None: self.process.report(f"Task: {awaitable.key} cancelled.") else: results = awaitable.result() + results = self.clean_socket_results(results) self.process.task_manager.state_manager.update_normal_task_state( awaitable.key, results ) diff --git a/src/aiida_workgraph/engine/imperative/__init__.py b/src/aiida_workgraph/engine/imperative/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aiida_workgraph/engine/imperative/imperative.py b/src/aiida_workgraph/engine/imperative/imperative.py new file mode 100644 index 00000000..0e29ed5f --- /dev/null +++ b/src/aiida_workgraph/engine/imperative/imperative.py @@ -0,0 +1,147 @@ +"""AiiDA workflow components: imperative WorkGraph engine.""" +from __future__ import annotations + +import typing as t + +from plumpy.persistence import auto_persist +from plumpy.process_states import Continue, Wait +from plumpy.workchains import _PropagateReturn +from aiida.engine.processes.exit_code import ExitCode +from aiida_workgraph.engine.workgraph import WorkGraphEngine, WorkGraphSpec +from aiida_workgraph.socket import TaskSocketNamespace, TaskSocket + +__all__ = ["WorkGraphImperativeEngine", "wait_for"] + + +@auto_persist("_awaitables") +class WorkGraphImperativeEngine(WorkGraphEngine): + """WorkGraph engine that builds the graph at run time from an async flow.""" + + def on_create(self) -> None: + """Called when a Process is created.""" + from aiida_workgraph.utils.analysis import WorkGraphSaver + from aiida_workgraph import WorkGraph + + super(WorkGraphEngine, self).on_create() + name = self._raw_inputs[WorkGraphSpec.WORKGRAPH_DATA_KEY]["name"] + flow = self.inputs[WorkGraphSpec.WORKGRAPH_DATA_KEY]["flow"] + self.node.label = name + wg = WorkGraph(name) + wg.add_task(flow, name="_flow") + saver = WorkGraphSaver( + self.node, + wg.prepare_inputs()[WorkGraphSpec.WORKGRAPH_DATA_KEY], + ) + saver.save() + + def execute_flow(self) -> t.Any: + from node_graph.executor import NodeExecutor + import
asyncio + + name = "_flow" + task = self.wg.tasks[name] + executor = NodeExecutor(**task.get_executor()).executor + inputs = self.inputs.workgraph_data["function_inputs"] + + awaitable_target = asyncio.ensure_future( + executor(**inputs), + loop=self.loop, + ) + awaitable = self.awaitable_manager.construct_awaitable_function( + name, awaitable_target + ) + self.task_manager.state_manager.set_task_runtime_info(name, "state", "RUNNING") + # save the awaitable to the temp, so that we can kill it if needed + self.awaitable_manager.not_persisted_awaitables[name] = awaitable_target + self.awaitable_manager.to_context(**{name: awaitable}) + + def setup(self) -> None: + """Setup the workgraph engine.""" + super().setup() + + # Execute the flow + self.execute_flow() + + def _do_step(self) -> t.Any: + """Execute the next step in the workgraph and return the result. + + If any awaitables were created, the process will enter in the Wait state, + otherwise it will go to Continue. + """ + result: t.Any = None + + try: + self.task_manager.continue_workgraph() + except _PropagateReturn as exception: + finished, result = True, exception.exit_code + else: + finished, result = self.task_manager.is_workgraph_finished() + + # If the workgraph is finished or the result is an ExitCode, we exit by returning + if finished and len(self._awaitables) == 0: + if isinstance(result, ExitCode): + return result + else: + return self.finalize() + + if self._awaitables: + self.awaitable_manager.action_awaitables() + return Wait(self._do_step, "Waiting before next step") + + return Continue(self._do_step) + + def finalize(self) -> t.Optional[ExitCode]: + """Finalize the workgraph. + Output the results of the workgraph and the new data. 
+ """ + # expose outputs of the workgraph + self.task_manager.state_manager.update_meta_tasks("graph_ctx") + self.wg.update() + graph_results = {} + # print("graph results:", self.ctx._task_results.get("_flow", {})) + for name, data in self.ctx._task_results.get("_flow", {}).items(): + socket = self.wg.tasks[data["task_name"]].outputs[data["socket_name"]] + graph_results[name] = ( + socket._value + if isinstance(socket, TaskSocketNamespace) + else socket.value + ) + self.out_many(graph_results) + # output the new data + if self.ctx._new_data: + self.out("new_data", self.ctx._new_data) + self.report("Finalize workgraph.") + for task in self.wg.tasks: + if ( + self.task_manager.state_manager.get_task_runtime_info( + task.name, "state" + ) + == "FAILED" + ): + return self.exit_codes.TASK_FAILED + + +# --------------------------------------------------------------------------------------- + + +async def wait_for( + socket: TaskSocket, interval: float = 5.0, timeout: float = 604800.0 +) -> None: + """ + Poll until the socket's node is FINISHED or FAILED, or the timeout elapses. + + :param socket: The TaskSocket instance to monitor. + :param interval: Seconds to sleep between two checks of the node state. + :param timeout: Max seconds to wait; on expiry prints a notice and returns.
+ """ + import time + import asyncio + + start_time = time.monotonic() + while socket._node.state not in ["FINISHED", "FAILED"]: + if time.monotonic() - start_time > timeout: + print("Timeout reached while waiting for node state.") + return + print(f"Task {socket._node.name} state:", socket._node.state) + await asyncio.sleep(interval) + socket._node.graph.update() diff --git a/src/aiida_workgraph/engine/task_state.py b/src/aiida_workgraph/engine/task_state.py index 52d985fe..5bec5ad7 100644 --- a/src/aiida_workgraph/engine/task_state.py +++ b/src/aiida_workgraph/engine/task_state.py @@ -52,6 +52,8 @@ def set_task_runtime_info(self, name: str, key: str, value: Any) -> None: self.process.node.set_task_process(name, serialized) elif key == "state": self.process.node.set_task_state(name, value) + if name in self.process.wg.tasks: + self.process.wg.tasks[name].state = value elif key == "action": self.process.node.set_task_action(name, value) elif key == "execution_count": diff --git a/src/aiida_workgraph/imperative.py b/src/aiida_workgraph/imperative.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_imperative.py b/tests/test_imperative.py new file mode 100644 index 00000000..0bbf09f6 --- /dev/null +++ b/tests/test_imperative.py @@ -0,0 +1,66 @@ +from aiida_workgraph import task +from aiida_workgraph.engine.imperative.imperative import ( + WorkGraphImperativeEngine, + wait_for, +) +from aiida.engine import run +import pytest + + +@task.pythonjob() +def add(x, y): + return x + y + + +@task.pythonjob() +def multiply(x, y): + return x * y + + +@pytest.mark.usefixtures("started_daemon_client") +def test_if(fixture_localhost): + async def if_flow(x, y): + add_result = add(x, y) + await wait_for(add_result) + add_result._node.graph.update() + if add_result.result.value > 10: + multiply_result = multiply(1, add_result.result) + else: + multiply_result = multiply(-1, add_result.result) + return {"sum": add_result.result, "multiply": 
multiply_result.result} + + results = run( + WorkGraphImperativeEngine, + inputs={ + "workgraph_data": { + "name": "if_flow", + "flow": if_flow, + "function_inputs": {"x": 3, "y": 4}, + } + }, + ) + assert results["sum"].value == 7 + + +@pytest.mark.usefixtures("started_daemon_client") +def test_while(fixture_localhost): + async def add_multiply(x, y): + outputs = add(x, y) + await wait_for(outputs) + while outputs.result.value < 10: + outputs1 = add(outputs.result, 1) + outputs = multiply(outputs1.result, y=2) + await wait_for(outputs) + return {"sum": outputs.result} + + results = run( + WorkGraphImperativeEngine, + inputs={ + "workgraph_data": { + "name": "add_multiply", + "flow": add_multiply, + "function_inputs": {"x": 3, "y": 4}, + } + }, + ) + assert results["sum"].value == 16