Skip to content
85 changes: 72 additions & 13 deletions runtime/src/iree/hal/drivers/amdgpu/aql_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -1690,11 +1690,33 @@ static iree_status_t iree_hal_amdgpu_aql_command_buffer_write_dispatch_tail(
}
return iree_ok_status();
}
case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT:
if (constants.data_length > 0) {
memcpy(tail_payload, constants.data, constants.data_length);
case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT: {
const iree_host_size_t explicit_bytes = layout->explicit_kernarg_size;
const iree_host_size_t copy_bytes =
constants.data_length < explicit_bytes ? constants.data_length
: explicit_bytes;
if (copy_bytes > 0) {
memcpy(tail_payload, constants.data, copy_bytes);
}
if (copy_bytes < explicit_bytes) {
memset(tail_payload + copy_bytes, 0, explicit_bytes - copy_bytes);
}
if (layout->has_implicit_args) {
iree_amdgpu_kernel_implicit_args_t* implicit_args =
(iree_amdgpu_kernel_implicit_args_t*)(tail_payload +
layout->implicit_args_offset);
iree_hal_amdgpu_aql_command_buffer_write_implicit_args(
kernel_args, config, implicit_args);
const iree_host_size_t implicit_args_end =
layout->implicit_args_offset +
IREE_AMDGPU_KERNEL_IMPLICIT_ARGS_SIZE;
if (layout->total_kernarg_size > implicit_args_end) {
memset(tail_payload + implicit_args_end, 0,
layout->total_kernarg_size - implicit_args_end);
}
}
return iree_ok_status();
}
case IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_INDIRECT:
return iree_make_status(
IREE_STATUS_UNIMPLEMENTED,
Expand Down Expand Up @@ -1880,6 +1902,11 @@ typedef struct iree_hal_amdgpu_aql_dispatch_plan_t {
// Kernarg layout selected for HAL or custom-direct arguments.
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* layout;

// Resolved per-dispatch custom-direct layout. Used when descriptor metadata
// leaves the raw kernarg size dynamic and the caller-provided byte length is
// the only exact reservation size.
iree_hal_amdgpu_device_dispatch_kernarg_layout_t custom_layout;

// Number of kernarg blocks required by the selected descriptor path.
uint32_t kernarg_block_count;

Expand Down Expand Up @@ -2017,18 +2044,43 @@ static iree_status_t iree_hal_amdgpu_aql_command_buffer_prepare_dispatch_plan(

if (iree_hal_amdgpu_aql_dispatch_plan_uses_custom_direct_arguments(
out_plan)) {
if (IREE_UNLIKELY(inputs->constants.data_length !=
out_plan->descriptor->kernel_args.kernarg_size)) {
// Callers (e.g. rocBLAS/Tensile) sometimes omit trailing ABI padding or pad
// beyond the declared kernarg_segment_size with extra trailing scalars. The
// kernel only reads its declared size, so trailing bytes are ignored and the
// memcpy in write_dispatch_tail clamps to the declared size.
//
// Validate after 8-byte ABI padding so we accept missing tail padding while
// still rejecting truly short pre-packed HIP argument buffers.
const uint32_t required_explicit_bytes =
(uint32_t)out_plan->descriptor->custom_kernarg_layout
.explicit_kernarg_size;
const iree_host_size_t padded_constant_length =
iree_host_align(inputs->constants.data_length, /*alignment=*/8);
if (IREE_UNLIKELY(padded_constant_length < required_explicit_bytes)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"custom dispatch argument length mismatch; expected %u but got "
"%" PRIhsz,
out_plan->descriptor->kernel_args.kernarg_size,
inputs->constants.data_length);
"custom dispatch argument length too short; expected at least %u "
"but got %" PRIhsz " (padded to %" PRIhsz ")",
required_explicit_bytes, inputs->constants.data_length,
padded_constant_length);
}
out_plan->custom_layout = out_plan->descriptor->custom_kernarg_layout;
if (out_plan->custom_layout.total_kernarg_size == 0) {
out_plan->custom_layout.explicit_kernarg_size =
inputs->constants.data_length;
out_plan->custom_layout.total_kernarg_size = inputs->constants.data_length;
}
out_plan->layout = &out_plan->descriptor->custom_kernarg_layout;
out_plan->layout = &out_plan->custom_layout;
out_plan->kernarg_block_count =
iree_max(1u, out_plan->descriptor->custom_kernarg_block_count);
if (out_plan->layout->total_kernarg_size > 0) {
const uint32_t provided_kernarg_block_count =
(uint32_t)iree_host_size_ceil_div(
out_plan->layout->total_kernarg_size,
sizeof(iree_hal_amdgpu_kernarg_block_t));
out_plan->kernarg_block_count =
iree_max(out_plan->kernarg_block_count, provided_kernarg_block_count);
}
out_plan->kernarg_strategy =
IREE_HAL_AMDGPU_COMMAND_BUFFER_KERNARG_STRATEGY_CUSTOM_DIRECT;
return iree_ok_status();
Expand Down Expand Up @@ -2102,14 +2154,21 @@ iree_hal_amdgpu_aql_command_buffer_calculate_dispatch_layout(
? 0
: (iree_host_size_t)plan->kernel_args->binding_count *
sizeof(uint64_t);
const iree_host_size_t tail_byte_length =
plan->layout->total_kernarg_size - binding_bytes;
const iree_host_size_t total_kernarg_size = plan->layout->total_kernarg_size;
if (IREE_UNLIKELY(total_kernarg_size < binding_bytes)) {
return iree_make_status(
IREE_STATUS_INVALID_ARGUMENT,
"dispatch kernarg size %" PRIhsz
" is smaller than binding table size %" PRIhsz,
total_kernarg_size, binding_bytes);
}
const iree_host_size_t tail_byte_length = total_kernarg_size - binding_bytes;
IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_qword_length(
tail_byte_length, "dispatch tail payload",
&out_layout->kernarg.tail_length_qwords,
&out_layout->kernarg.tail_padded_length));
IREE_RETURN_IF_ERROR(iree_hal_amdgpu_aql_command_buffer_qword_length(
plan->layout->total_kernarg_size, "dispatch kernarg",
total_kernarg_size, "dispatch kernarg",
&out_layout->kernarg.total_length_qwords,
&out_layout->kernarg.total_padded_length));
out_layout->kernarg.implicit_args_offset_qwords =
Expand Down
26 changes: 17 additions & 9 deletions runtime/src/iree/hal/drivers/amdgpu/device/dispatch.c
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ void iree_hal_amdgpu_device_dispatch_emplace_packet(
// Dispatch kernarg emission
//===----------------------------------------------------------------------===//

static void iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
void iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
kernel_args,
const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
Expand Down Expand Up @@ -70,7 +70,6 @@ static void iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
void iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
kernel_args,
const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
layout,
const uint64_t* IREE_AMDGPU_RESTRICT binding_ptrs,
Expand All @@ -89,20 +88,29 @@ void iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
iree_amdgpu_memcpy((uint8_t*)kernarg_ptr + binding_bytes, constants,
constant_bytes);
}

iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
kernel_args, workgroup_count, dynamic_workgroup_local_memory, layout,
kernarg_ptr);
}

void iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
layout,
const void* IREE_AMDGPU_RESTRICT custom_kernarg_ptr,
size_t custom_kernarg_length,
void* IREE_AMDGPU_RESTRICT kernarg_ptr) {
if (layout->total_kernarg_size > 0) {
iree_amdgpu_memcpy(kernarg_ptr, custom_kernarg_ptr,
layout->total_kernarg_size);
const size_t total_kernarg_size =
layout->total_kernarg_size ? layout->total_kernarg_size
: custom_kernarg_length;
if (total_kernarg_size > 0) {
iree_amdgpu_memset(kernarg_ptr, 0, total_kernarg_size);
const size_t explicit_bytes =
layout->has_implicit_args
? layout->implicit_args_offset
: total_kernarg_size;
const size_t copy_bytes =
custom_kernarg_length < explicit_bytes ? custom_kernarg_length
: explicit_bytes;
if (copy_bytes > 0) {
iree_amdgpu_memcpy(kernarg_ptr, custom_kernarg_ptr, copy_bytes);
}
}
}

Expand Down
54 changes: 30 additions & 24 deletions runtime/src/iree/hal/drivers/amdgpu/device/dispatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,22 +78,6 @@ iree_hal_amdgpu_device_dispatch_make_hal_kernarg_layout(
};
}

// Returns a custom-direct-argument layout for a raw kernarg blob of
// |kernarg_size| bytes.
//
// The caller owns all packing and padding in the raw argument blob. No implicit
// suffix is synthesized in this mode.
static inline iree_hal_amdgpu_device_dispatch_kernarg_layout_t
iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout(
size_t kernarg_size) {
return (iree_hal_amdgpu_device_dispatch_kernarg_layout_t){
.explicit_kernarg_size = kernarg_size,
.implicit_args_offset = kernarg_size,
.total_kernarg_size = kernarg_size,
.has_implicit_args = false,
};
}

//===----------------------------------------------------------------------===//
// Dispatch Packet/Kernarg Emission
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -152,7 +136,26 @@ void iree_hal_amdgpu_device_dispatch_emplace_packet(
iree_hsa_kernel_dispatch_packet_t* IREE_AMDGPU_RESTRICT dispatch_packet,
void* IREE_AMDGPU_RESTRICT kernarg_ptr);

// Populates HAL ABI kernargs in already-reserved storage.
// Populates the HIP/OpenCL implicit args suffix in already-reserved storage.
//
// This must be called after explicit HAL/custom kernargs have been populated
// whenever |layout->has_implicit_args| is true.
//
// Preconditions:
// - |kernel_args|, |workgroup_count|, |layout|, and |kernarg_ptr| are
// non-NULL.
// - |layout| describes a reservation with an implicit suffix.
// - |kernarg_ptr| points to at least |layout->total_kernarg_size| bytes of
// writable storage.
void iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
kernel_args,
const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
layout,
void* IREE_AMDGPU_RESTRICT kernarg_ptr);

// Populates HAL ABI explicit kernargs in already-reserved storage.
//
// |binding_ptrs| must provide |kernel_args->binding_count| device pointers as
// raw 64-bit values. |constants| must provide
Expand All @@ -169,27 +172,30 @@ void iree_hal_amdgpu_device_dispatch_emplace_packet(
void iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
const iree_hal_amdgpu_device_kernel_args_t* IREE_AMDGPU_RESTRICT
kernel_args,
const uint32_t workgroup_count[3], uint32_t dynamic_workgroup_local_memory,
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
layout,
const uint64_t* IREE_AMDGPU_RESTRICT binding_ptrs,
const uint32_t* IREE_AMDGPU_RESTRICT constants,
void* IREE_AMDGPU_RESTRICT kernarg_ptr);

// Populates custom direct kernargs in already-reserved storage.
// Populates custom direct explicit kernargs in already-reserved storage.
//
// |custom_kernarg_ptr| must provide |layout->total_kernarg_size| bytes in the
// final kernel ABI shape expected by the target kernel.
// |custom_kernarg_ptr| provides up to |layout->total_kernarg_size| bytes in the
// final kernel ABI shape expected by the target kernel. Missing trailing padding
// bytes remain zeroed.
//
// Preconditions:
// - |layout| and |kernarg_ptr| are non-NULL.
// - |layout| was derived with
// iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout.
// - |custom_kernarg_ptr| is non-NULL when |layout->total_kernarg_size| > 0.
// - |layout| describes either a fixed custom-direct reservation with optional
// implicit suffix storage or a dynamic custom-direct reservation where
// |layout->total_kernarg_size == 0| and |custom_kernarg_length| determines
// the reservation size.
// - |custom_kernarg_ptr| is non-NULL when |custom_kernarg_length| > 0.
void iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
const iree_hal_amdgpu_device_dispatch_kernarg_layout_t* IREE_AMDGPU_RESTRICT
layout,
const void* IREE_AMDGPU_RESTRICT custom_kernarg_ptr,
size_t custom_kernarg_length,
void* IREE_AMDGPU_RESTRICT kernarg_ptr);

// Populates the builtin patch dispatch that updates an indirect-parameter
Expand Down
12 changes: 6 additions & 6 deletions runtime/src/iree/hal/drivers/amdgpu/device/dispatch_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -111,9 +111,10 @@ TEST(DispatchTest, EmplaceHalKernargsWritesBindingsConstantsAndImplicitArgs) {
kernargs.fill(0xFD);

iree_hal_amdgpu_device_dispatch_emplace_hal_kernargs(
&kernel_args, workgroup_count,
/*dynamic_workgroup_local_memory=*/13, &layout, bindings, constants,
kernargs.data());
&kernel_args, &layout, bindings, constants, kernargs.data());
iree_hal_amdgpu_device_dispatch_emplace_implicit_args(
&kernel_args, workgroup_count, /*dynamic_workgroup_local_memory=*/13,
&layout, kernargs.data());

const uint64_t* binding_words =
reinterpret_cast<const uint64_t*>(kernargs.data());
Expand Down Expand Up @@ -158,8 +159,7 @@ TEST(DispatchTest, EmplaceHalKernargsWritesBindingsConstantsAndImplicitArgs) {
}

TEST(DispatchTest, EmplaceCustomKernargsCopiesRawBlob) {
iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout =
iree_hal_amdgpu_device_dispatch_make_custom_kernarg_layout(20);
iree_hal_amdgpu_device_dispatch_kernarg_layout_t layout = {};
const std::array<uint8_t, 20> custom_kernargs = {
0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09,
0x0A, 0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13,
Expand All @@ -168,7 +168,7 @@ TEST(DispatchTest, EmplaceCustomKernargsCopiesRawBlob) {
kernargs.fill(0xFD);

iree_hal_amdgpu_device_dispatch_emplace_custom_kernargs(
&layout, custom_kernargs.data(), kernargs.data());
&layout, custom_kernargs.data(), custom_kernargs.size(), kernargs.data());

EXPECT_EQ(std::memcmp(kernargs.data(), custom_kernargs.data(),
custom_kernargs.size()),
Expand Down
Loading