Skip to content

Commit c1e6e9b

Browse files
committed
Document dynamic operator wrapper arguments
- Exposes batch_size and device as distinct dynamic operator wrapper parameters so they can be tracked as reserved kwargs. Add generated numpydoc entries for batch_size, device, rng, and seed when those arguments are present in dynamic operator signatures. Signed-off-by: Janusz Lisiecki <jlisiecki@nvidia.com>
1 parent 2842a83 commit c1e6e9b

2 files changed

Lines changed: 36 additions & 9 deletions

File tree

dali/python/nvidia/dali/experimental/dynamic/_op_builder.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -450,8 +450,8 @@ def build_fn_wrapper(op, fn_name=None, add_to_module=True):
450450

451451
fixed_args = []
452452
tensor_args = []
453-
signature_args = ["batch_size=None, device=None"]
454-
used_kwargs = set()
453+
signature_args = ["batch_size=None", "device=None"]
454+
used_kwargs = {"batch_size", "device"}
455455

456456
for arg in op._schema.GetArgumentNames():
457457
if arg in _unsupported_args:
@@ -473,8 +473,9 @@ def build_fn_wrapper(op, fn_name=None, add_to_module=True):
473473
# Remove 'seed' from used_kwargs and signature_args if present
474474
if "seed" in used_kwargs:
475475
used_kwargs.remove("seed")
476-
if "seed" in signature_args:
477-
signature_args.remove("seed")
476+
for arg in signature_args:
477+
if "seed" in arg:
478+
signature_args.remove(arg)
478479

479480
header = f"{fn_name}({', '.join(inputs + signature_args)})"
480481

dali/python/nvidia/dali/ops/_docs.py

Lines changed: 31 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2023-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
1+
# Copyright (c) 2023-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -114,6 +114,26 @@ def _get_inputs_doc(schema, api):
114114
return ret
115115

116116

117+
def _get_batch_size_doc():
118+
"""Return documentation for the batch_size argument used in dynamic API."""
119+
return _numpydoc_formatter(
120+
"batch_size",
121+
"int",
122+
"The batch size to broadcast input tensors to. Ignored for batch inputs.",
123+
optional=True,
124+
)
125+
126+
127+
def _get_device_doc():
128+
"""Return documentation for the device argument used in dynamic API."""
129+
return _numpydoc_formatter(
130+
"device",
131+
"device-like",
132+
"The device to use for the operation. Must not conflict with the device of the inputs.",
133+
optional=True,
134+
)
135+
136+
117137
def _get_rng_doc():
118138
"""Return documentation for the rng argument used in dynamic API random operators."""
119139
return _numpydoc_formatter(
@@ -190,10 +210,16 @@ def _get_kwargs(schema, api="ops", args=None):
190210
ret += "\n"
191211

192212
# Add rng documentation for dynamic API random operators
193-
if api == "dynamic" and args is not None and "rng" in args:
194-
ret += _get_rng_doc()
195-
ret += "\n"
196-
213+
if api == "dynamic" and args is not None:
214+
if "batch_size" in args:
215+
ret += _get_batch_size_doc()
216+
ret += "\n"
217+
if "device" in args:
218+
ret += _get_device_doc()
219+
ret += "\n"
220+
if "rng" in args:
221+
ret += _get_rng_doc()
222+
ret += "\n"
197223
return ret
198224

199225

0 commit comments

Comments
 (0)