Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
43c0570
Remove nvidia_wheel_versions
charleshofer Nov 12, 2025
bcef89c
Make jaxlib targets visible
charleshofer Nov 12, 2025
733b7bf
hipblas typedef fix
charleshofer Nov 12, 2025
793d312
No GPU fail
charleshofer Nov 13, 2025
e3ad0ec
Wrap HIP inline functions in anonymous namespaces in vendor.h
mminutoli Feb 12, 2026
a831ef2
SWDEV-512768 - Replace hipGetLastError with hipExtGetLastError
dsicarov-amd Jun 10, 2025
58249a4
Add shared utility function get_rocm_version to test_util.py
charleshofer Nov 14, 2025
e587f90
Fix hipSparse CSR algorithm mappings for ROCm 7
phambinhfin Nov 17, 2025
8089947
Fix v_pages quantization and adjust test params for ROCm compatibilit…
phambinhfin Nov 19, 2025
d9e7020
Address LLVM assertion failure due to a multithreaded use. Update .gi…
Arech8 Nov 26, 2025
42a3be6
Add skip of test_is_finite() on Cuda (#565)
Arech8 Nov 26, 2025
544c6d4
Add rocm test requirements file (#570)
AratiGanesh Dec 15, 2025
4673584
Let the unit tests use build.py for setting up Bazel commands for uni…
charleshofer Dec 15, 2025
1c79814
adding abort logic to rocm/jax (#590)
gulsumgudukbay Jan 13, 2026
9b5d708
Skip is_finite tests on ROCm (not in Triton lowering for jax 0.8.0) (…
phambinhfin Jan 14, 2026
82bf13e
Fix shared memory limit check for ROCm in test_dot (#596)
phambinhfin Jan 14, 2026
ad47e17
Fix Numpy signatures test (#598)
magaonka-amd Jan 14, 2026
3b3b31c
fix merge arts
Ruturaj4 Jan 18, 2026
8a9adef
Enable RngShardingTests (#644)
gulsumgudukbay Jan 22, 2026
4eb7473
Enable test_variadic_reduce_window on ROCm (#647)
mminutoli Feb 12, 2026
f360e13
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
81842f4
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
489fcf6
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
c2ea7b4
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
82a1e81
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
2c12a03
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
5c681b1
Enable testMultivariateNormalSingularCovariance on ROCm (#666)
AratiGanesh Jan 28, 2026
3757b64
Update Skip Reason Outputs (#663)
gulsumgudukbay Jan 28, 2026
7ec9fe0
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
411d4fa
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
130ca42
Skip testCudaArrayInterfaceOnNonCudaFails on ROCm platform (#677)
magaonka-amd Jan 29, 2026
837654d
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
2b8c7fe
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
d0a11b3
Skip sparse tests on ROCm due to hipSPARSE issue (#652)
magaonka-amd Jan 23, 2026
3e79165
Update sparse test skip messages in v0.8.2 (#653)
magaonka-amd Jan 23, 2026
9d4fce1
Remove 'mean' from unsupported params for jnp.var (#689)
magaonka-amd Feb 6, 2026
7601b86
Skipping testEighTinyNorm due to hipSolver issues (#697)
AratiGanesh Feb 9, 2026
c21c3b8
Abort detection CI workflow (#688)
gulsumgudukbay Feb 20, 2026
3f5828e
Abort-Detection: Fix halt-for-connection input (#712)
gulsumgudukbay Feb 24, 2026
6283bbc
fix: add rocm_sysdeps/lib to wheel RUNPATH (#737)
WBobby Mar 17, 2026
2422c28
Temporarily disable the cron trigger for the cont. wheel tests workflow
psanal35 Apr 8, 2026
8d4fbef
Add placeholder for nightly benchmark workflow (#768)
psanal35 May 10, 2026
eb58ba2
Rename the benchmarks workflow consistently (#770)
psanal35 May 10, 2026
2483b46
Add ROCm benchmark workflow for MaxText
psanal35 Apr 23, 2026
1d25b87
Resolve latest MaxText Transformer Engine wheel from ROCm MaxText rel…
psanal35 May 10, 2026
24fb105
Load MaxText ROCm benchmark configs and requirements from ROCm/maxtext
psanal35 May 11, 2026
34b2d0f
Revisit ROCm benchmark results and run-manifest collection
psanal35 May 11, 2026
37fe3f7
Revisit ROCm artifact upload to S3 for reusability
psanal35 May 11, 2026
9183f20
Remove TE installation to keep the model lightweight
psanal35 May 13, 2026
182dd2a
Update benchmark target scripts for more generic use cases
psanal35 May 15, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions .github/workflows/benchmark_rocm.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
# CI - Benchmark ROCm
#
# This workflow runs the ROCm benchmarks in ROCm team's GHCR containers.
# It can be triggered manually via workflow_dispatch or called by other workflows
# via workflow_call.
#
# It consists of the following job:
# run-benchmarks:
# - Runs in ROCm team's container (ghcr.io/rocm/jax-base-ubu24-rocm*:latest)
# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from S3.
# - Executes the target benchmark scripts at `targets/<target>/run.sh`.
name: CI - Benchmark ROCm
on:
workflow_dispatch:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: choice
default: "linux-x86-64-8gpu-amd"
options:
- "linux-x86-64-1gpu-amd"
- "linux-x86-64-4gpu-amd"
- "linux-x86-64-8gpu-amd"
python:
description: "Which Python version to use?"
type: choice
default: "3.12"
options:
- "3.11"
- "3.12"
rocm-version:
description: "Which ROCm version to benchmark?"
type: choice
default: "7.2.0"
options:
- "7.2.0"
rocm-tag:
description: "ROCm tag for container image (e.g., rocm720)"
type: string
default: "rocm720"
target:
description: "Benchmark target"
type: choice
default: "maxtext"
options:
- "maxtext"
workload:
description: "Benchmark workload"
type: string
default: "llama3_8b"
jaxlib-version:
description: "Which jaxlib version to use? (head/pypi_latest)"
type: choice
default: "head"
options:
- "head"
- "pypi_latest"
skip-download-jaxlib-and-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)"
type: choice
default: '0'
options:
- '0'
- '1'
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
type: string
default: 'gs://jax-nightly-artifacts/latest'
s3_download_uri:
description: "S3 URI for ROCm plugin/PJRT wheels (use 'latest' to resolve via LATEST pointer)"
type: string
default: 'latest'
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
default: "linux-x86-64-8gpu-amd"
python:
description: "Which Python version to use?"
type: string
default: "3.12"
rocm-version:
description: "Which ROCm version to benchmark?"
type: string
default: "7.2.0"
rocm-tag:
description: "ROCm tag for container image (e.g., rocm720)"
type: string
default: "rocm720"
target:
description: "Benchmark target"
type: string
default: "maxtext"
workload:
description: "Benchmark workload"
type: string
default: "llama3_8b"
jaxlib-version:
description: "Which jaxlib version to use? (head/pypi_latest)"
type: string
default: "head"
skip-download-jaxlib-and-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)"
type: string
default: '0'
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
type: string
default: 'gs://jax-nightly-artifacts/latest'
s3_download_uri:
description: "S3 URI for ROCm plugin/PJRT wheels (use 'latest' to resolve via LATEST pointer)"
type: string
default: 'latest'
permissions:
id-token: write
contents: read

env:
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
run-benchmarks:
defaults:
run:
# Set the shell to bash as GitHub actions run with /bin/sh by default
shell: bash
runs-on: ${{ inputs.runner }}
continue-on-error: true
# Run in ROCm team's GHCR container with GPU access
container:
image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest # zizmor: ignore[unpinned-images]
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings
name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') ||
(contains(inputs.runner, '4gpu') && '4gpu') ||
(contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}"

env:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "0"
INPUT_ROCM_VERSION: "${{ inputs.rocm-version }}"
INPUT_RUNNER: "${{ inputs.runner }}"
INPUT_ROCM_TAG: "${{ inputs.rocm-tag }}"

steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
persist-credentials: false
- name: Download JAX ROCm wheels
uses: ./.github/actions/download-jax-rocm-wheels
with:
python: ${{ inputs.python }}
rocm-version: ${{ inputs.rocm-version }}
jaxlib-version: ${{ inputs.jaxlib-version }}
skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
gcs_download_uri: ${{ inputs.gcs_download_uri }}
s3_download_uri: ${{ inputs.s3_download_uri }}
- name: Install Python dependencies
run: |
$JAXCI_PYTHON -m pip install uv~=0.11.2
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run ROCm benchmarks
env:
TARGET: ${{ inputs.target }}
WORKLOAD: ${{ inputs.workload }}
timeout-minutes: 120
run: ./ci/benchmark_targets/${TARGET}_rocm/run_${TARGET}_rocm.sh "${WORKLOAD}"
- name: Upload GitHub artifacts
if: always()
continue-on-error: true
uses: actions/upload-artifact@v4
with:
name: benchmark-artifacts-${{ inputs.target }}-${{ inputs.workload }}-${{ inputs.runner }}-py${{ inputs.python }}-rocm${{ inputs.rocm-version }}
path: ci/benchmark_targets/${{ inputs.target }}_rocm/run_artifacts/${{ inputs.workload }}/result.json
if-no-files-found: warn
# - name: Upload CI artifacts to S3
# if: always()
# env:
# S3_BUCKET_NAME: jax-ci-amd
# run: |
# RUN_KEY="$(date -u +%F)_${GITHUB_RUN_ID}_${GITHUB_RUN_ATTEMPT}"
# PREFIX="jax-ci-bench-logs/${GITHUB_REPOSITORY}/${GITHUB_REF_NAME}/${RUN_KEY}"
# ./ci/upload_rocm_artifacts.sh \
# "ci/benchmark_targets/${{ inputs.target }}_rocm/run_artifacts/${{ inputs.workload }}" \
# "${PREFIX}" \
# success
175 changes: 175 additions & 0 deletions .github/workflows/pytest_rocm_abort.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# CI - Pytest ROCm (Abort Support)
#
# This workflow runs the ROCm tests with Pytest in ROCm GHCR containers,
# using the ROCm `pytest-abort` retry wrapper to detect/retry aborts/crashes.
#
# It can be triggered manually via workflow_dispatch or called by other workflows
# via workflow_call.
#
# It consists of the following job:
# run-tests:
# - Runs in ROCm container (ghcr.io/rocm/jax-base-ubu24-rocm*:latest)
# - Downloads the JAX and jaxlib wheels from GCS, and ROCm plugins from latest release.
# - Executes the `run_pytest_rocm_abort.sh` script, which installs wheel artifacts and
# runs the ROCm tests with Pytest under `pytest-abort-retry`.
name: CI - Pytest ROCm (Abort Support)

on:
workflow_dispatch:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: choice
default: "linux-x86-64-4gpu-amd"
options:
- "linux-x86-64-1gpu-amd"
- "linux-x86-64-4gpu-amd"
- "linux-x86-64-8gpu-amd"
python:
description: "Which Python version to use?"
type: choice
default: "3.11"
options:
- "3.11"
- "3.12"
rocm-version:
description: "Which ROCm version to test?"
type: choice
default: "7.2.0"
options:
- "7.2.0"
rocm-tag:
description: "ROCm tag for container image (e.g., rocm720)"
type: string
default: "rocm720"
jaxlib-version:
description: "Which jaxlib version to use? (head/pypi_latest)"
type: choice
default: "head"
options:
- "head"
- "pypi_latest"
skip-download-jaxlib-and-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)"
type: choice
default: '0'
options:
- '0'
- '1'
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
type: string
default: 'gs://jax-nightly-artifacts/latest'
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'
max-worker-restart:
description: "Max xdist worker restarts (passed to pytest --max-worker-restart)"
type: string
default: '50'
workflow_call:
inputs:
runner:
description: "Which runner should the workflow run on?"
type: string
default: "linux-x86-64-4gpu-amd"
python:
description: "Which Python version to use?"
type: string
default: "3.11"
rocm-version:
description: "Which ROCm version to test?"
type: string
default: "7.2.0"
rocm-tag:
description: "ROCm tag for container image (e.g., rocm720)"
type: string
default: "rocm720"
jaxlib-version:
description: "Which jaxlib version to use? (head/pypi_latest)"
type: string
default: "head"
skip-download-jaxlib-and-plugins-from-gcs:
description: "Whether to skip downloading the jaxlib and plugins from GCS (e.g for testing a jax only release)"
default: '0'
type: string
gcs_download_uri:
description: "GCS location prefix from where the artifacts should be downloaded"
default: 'gs://jax-nightly-artifacts/latest'
type: string
halt-for-connection:
description: 'Should this workflow run wait for a remote connection?'
type: string
default: 'no'
max-worker-restart:
description: "Max xdist worker restarts (passed to pytest --max-worker-restart)"
type: string
default: '50'

permissions: {}

env:
UV_DEFAULT_INDEX: "https://us-python.pkg.dev/ml-oss-artifacts-published/pypi-mirror/simple"

jobs:
run-tests:
defaults:
run:
# Set the shell to bash as GitHub actions run with /bin/sh by default
shell: bash
runs-on: ${{ inputs.runner }}
continue-on-error: true
# Run in ROCm GHCR container with GPU access
container:
image: ghcr.io/rocm/jax-base-ubu24.${{ inputs.rocm-tag }}:latest
credentials:
username: ${{ github.actor }}
password: ${{ secrets.GITHUB_TOKEN }}
options: --device=/dev/kfd --device=/dev/dri --security-opt seccomp=unconfined --group-add video --shm-size 64G --env-file /etc/podinfo/gha-gpu-isolation-settings
name: "${{ (contains(inputs.runner, '1gpu') && '1gpu') ||
(contains(inputs.runner, '4gpu') && '4gpu') ||
(contains(inputs.runner, '8gpu') && '8gpu') }}, ROCm ${{ inputs.rocm-version }}, py${{ inputs.python }}"

env:
JAXCI_HERMETIC_PYTHON_VERSION: "${{ inputs.python }}"
JAXCI_PYTHON: "python${{ inputs.python }}"
JAXCI_ENABLE_X64: "0"
MAX_WORKER_RESTART: "${{ inputs['max-worker-restart'] }}"

steps:
- uses: actions/checkout@1af3b93b6815bc44a9784bd300feb67ff0d1eeb3 # v6.0.0
with:
persist-credentials: false
- name: Download JAX ROCm wheels
uses: ./.github/actions/download-jax-rocm-wheels
with:
python: ${{ inputs.python }}
rocm-version: ${{ inputs.rocm-version }}
jaxlib-version: ${{ inputs.jaxlib-version }}
skip-download-jaxlib-and-plugins-from-gcs: ${{ inputs.skip-download-jaxlib-and-plugins-from-gcs }}
gcs_download_uri: ${{ inputs.gcs_download_uri }}
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Install Python dependencies
run: |
$JAXCI_PYTHON -m pip install uv~=0.5.30
$JAXCI_PYTHON -m uv pip install -r build/test-requirements.txt
# Halt for testing
- name: Wait For Connection
uses: google-ml-infra/actions/ci_connection@7f5ca0c263a81ed09ea276524c1b9192f1304e3c
with:
halt-dispatch-input: ${{ inputs.halt-for-connection }}
- name: Run Pytest ROCm tests (abort support)
timeout-minutes: 180
run: ./ci/run_pytest_rocm_abort.sh
- name: Upload pytest results to artifact
if: always()
uses: actions/upload-artifact@v4
with:
name: logs_abort
path: |
logs_abort/
if-no-files-found: warn
retention-days: 2
overwrite: true
Loading