Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions commit.txt
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm assuming this won't be part of the final PR?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this is for me to keep track of the issues fixed; I will remove it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
Fix cpplint violations in common and PyTorch extension code

transformer_engine/common/amd_detail/hip_float8.h
-Host constructor: multi-statement if/else now uses braces (readability/braces).

transformer_engine/common/cast/mxfp8/rocm_quantize_mxfp8.cuh
-Include <cstdint>; typedef for gfx950 vector type uses int16_t instead of
short (runtime/int).

transformer_engine/common/ck_fused_attn/src/ck_fused_attn_utils.cpp
-dladdr: avoid ill-formed function-pointer-to-void* cast via a small union
(readability/casting / portable POSIX).
-get_ck_log_stream: else branch restructured with nested if so else/brace
pairing satisfies cpplint (readability/braces).

transformer_engine/common/fused_attn_rocm/fused_attn.cpp
-check_set_window_size: replace std::make_pair<int64_t,int64_t>(...) with
std::pair<int64_t,int64_t>(...) (build/explicit_make_pair).
-Replace alternative tokens `or` with || (readability/alt_tokens).
-log_fused_attn_config: same for sliding-window condition.

transformer_engine/common/gemm/rocm_gemm.cu
-ObjCache / NameMapper: mark single-argument constructors explicit
(runtime/explicit).
-HIPBLASLT scaling_mode check: split #if/#else branches so each if has its
own braced body; use static_cast<int> instead of C-style cast
(readability/braces, readability/casting).
-Debug logging: (int) casts -> static_cast<int> for hipDataType fields
(readability/casting).
-ServiceStreamKey: use std::uint64_t alias instead of unsigned long long
(runtime/int).

transformer_engine/common/normalization/common.cpp
-getNormalizationPlan: after optional CUDNN plan, use if (!plan) { ... } for
TE plans instead of } else #endif if (readability/braces across preprocessor).

transformer_engine/common/normalization/layernorm/ln_api.cpp
-Forward/backward: default norm_backend to Te; optional CUDNN path only under
#ifndef __HIP_PLATFORM_AMD__; set is_aligned only when backend is Te, so
preprocessor does not split if/else from its braces (readability/braces).

transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp
-Same pattern as ln_api for forward (including HIP constexpr
gamma_in_weight_dtype) and backward cudnn vs Te (readability/braces).

transformer_engine/common/permutation/permutation.cu
-MoE unpermute kernel: functional-style float(...) casts replaced with
static_cast<float>(...) (readability/casting).

transformer_engine/common/util/logging.h
-NVTE_CHECK_HIPBLASLT macro: std::to_string((int)status) ->
std::to_string(static_cast<int>(status)) (readability/casting).

transformer_engine/pytorch/csrc/extensions/gemm.cpp
-Comm overlap RS path: HIP p2p vs split_overlap_rs restructured with proper
#else for non-HIP so } else #endif { does not confuse brace rules
(readability/braces).
6 changes: 5 additions & 1 deletion transformer_engine/common/amd_detail/hip_float8.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,11 @@ union _te_hip_fp8 {
__device__ operator float() const;

__host__ _te_hip_fp8<FNUZ, OCP>(const float& v) {
if (te_fp8_fnuz()) fnuz=v; else ocp=v;
if (te_fp8_fnuz()) {
fnuz = v;
} else {
ocp = v;
}
}
__device__ _te_hip_fp8<FNUZ, OCP>(const float& v);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
// drop-in replacement for rocm quantize_mxfp8 kernels
//#include "hip/hip_runtime.h" //dummy include to prevent hipification adding this header

#include <cstdint>

constexpr size_t MXFP8_CHUNK_DIM_Y = 64;
constexpr size_t MXFP8_CHUNK_DIM_X = 64;
constexpr size_t MXFP8_THREADS_PER_CHUNK = 64;
Expand All @@ -15,7 +17,7 @@ constexpr size_t ELEMS_PER_THREAD = 16;
constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported

#if defined(__gfx950__) && __HIP_DEVICE_COMPILE__
typedef short mxfp8_v2i16_t __attribute__((ext_vector_type(2)));
typedef int16_t mxfp8_v2i16_t __attribute__((ext_vector_type(2)));
#endif

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,13 @@ void set_aiter_asm_dir() {
static std::once_flag aiter_asm_dir_once;
std::call_once(aiter_asm_dir_once, []() {
Dl_info info;
dladdr((void*)set_aiter_asm_dir, &info);
// dladdr expects void*; avoid reinterpret_cast<void*>(fn) (not ISO C++).
union {
void (*fn)();
void *addr;
} sym{};
sym.fn = set_aiter_asm_dir;
dladdr(sym.addr, &info);
Comment on lines +79 to +85
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IMO this is unnecessary. Yes, it quiets the warning, but the warning is irrelevant for us given that our support is POSIX-focused to begin with. From the dlopen man page:

           /* According to the ISO C standard, casting between function
              pointers and 'void *', as done above, produces undefined results.
              POSIX.1-2001 and POSIX.1-2008 accepted this state of affairs and
              proposed the following workaround:

                  *(void **) &cosine = dlsym(handle, "cos");

              This (clumsy) cast conforms with the ISO C standard and will
              avoid any compiler warnings.

              The 2013 Technical Corrigendum 1 to POSIX.1-2008 improved matters
              by requiring that conforming implementations support casting
              'void *' to a function pointer.  Nevertheless, some compilers
              (e.g., gcc with the '-pedantic' option) may complain about the
              cast used in this program.  */

the union trick here provides no additional safety -- it's still undefined behavior technically speaking -- and will break in the same circumstances (non-POSIX risk).

All things considered, I'd rather we keep things as-is, and if we really want to deal with the warning, we can make a small utility to use pragmas to suppress the warnings locally around the cast.

const char* log_ck_config_env = std::getenv("NVTE_LOG_CK_CONFIG");
bool log_ck_config = log_ck_config_env && std::string(log_ck_config_env) == "1";
// Check if user has set AITER_ASM_DIR, if yes, skip auto setting and log
Expand Down Expand Up @@ -130,9 +136,10 @@ std::ostream* get_ck_log_stream() {
if (!log_dir_str.empty() && log_dir_str != "0") {
if (log_dir_str == "1") {
log_stream = &std::cout;
}
else if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) {
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the warning for if / else if? I think that pattern is used a lot in our code.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The category was readability/braces. The message was along the lines of: “If an else has a brace on one side, it should have it on both.”

So the warning showed up because cpplint’s readability/braces heuristic fired on this if / else if layout, not because else if is forbidden. Nesting as else { if (...) { ... } } makes the structure obvious to the linter and cleared the warning.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Move them onto a single line then, but do not create nested ifs.

log_stream = &log_file;
} else {
if (open_ck_fused_attn_log_file(log_file, "ck_fused_attn", log_dir_str)) {
log_stream = &log_file;
}
}
}
}
Expand Down
10 changes: 5 additions & 5 deletions transformer_engine/common/fused_attn_rocm/fused_attn.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,26 +146,26 @@ std::pair<int64_t, int64_t> check_set_window_size(NVTE_Mask_Type attn_mask_type,
nvte_log_fused_attn_config = true;
}
if(attn_mask_type==NVTE_CAUSAL_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_MASK || attn_mask_type==NVTE_CAUSAL_BOTTOM_RIGHT_MASK || attn_mask_type==NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK){
if(window_size==std::make_pair<int64_t, int64_t>(-1, -1) || (window_size.first >=0 && window_size.second!=0)){
if(window_size==std::pair<int64_t, int64_t>(-1, -1) || (window_size.first >=0 && window_size.second!=0)){
//TODO: better INFO logging
if(nvte_log_fused_attn_config){
std::cout<<"window_size should be (-1, 0) or (>=0, 0) for attn_mask_type="<<attn_mask_type<<std::endl;
}
window_size.second = 0;
return window_size;
}else if( window_size!=std::make_pair<int64_t, int64_t>(-1, 0) && (window_size.first < 0 || window_size.second != 0)){
}else if( window_size!=std::pair<int64_t, int64_t>(-1, 0) && (window_size.first < 0 || window_size.second != 0)){
NVTE_ERROR("window_size should be (-1, 0) or (>=0, 0) for attn_mask_type=" + std::to_string(attn_mask_type));
}
}else if(attn_mask_type==NVTE_NO_MASK || attn_mask_type==NVTE_PADDING_MASK){
//no_mask and padding mask
if(window_size==std::make_pair<int64_t, int64_t>(-1, 0)){
if(window_size==std::pair<int64_t, int64_t>(-1, 0)){
//TODO: better INFO logging
if(nvte_log_fused_attn_config){
std::cout<<"window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type="<<attn_mask_type<<std::endl;
}
window_size.second=-1;
return window_size;
}else if(window_size!=std::make_pair<int64_t, int64_t>(-1, -1) && (window_size.first < 0 or window_size.second < 0)){
}else if(window_size!=std::pair<int64_t, int64_t>(-1, -1) && (window_size.first < 0 || window_size.second < 0)){
NVTE_ERROR("window_size should be (-1, -1) or (>=0, >=0) for attn_mask_type=" + std::to_string(attn_mask_type));
}
}else{
Expand Down Expand Up @@ -267,7 +267,7 @@ void log_fused_attn_config(
std::cout<<"d_qk: "<<head_dim_qk<<", ";
std::cout<<"d_v: "<<head_dim_v<<", ";
std::cout<<"(window_size_left, window_size_right): ("<<window_size_left<<", "<<window_size_right<<") ";
if(window_size_left >0 or window_size_right >0){
if(window_size_left >0 || window_size_right >0){
std::cout<<", (sliding window)";
}
std::cout<<std::endl;
Expand Down
25 changes: 14 additions & 11 deletions transformer_engine/common/gemm/rocm_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public:
data[key][stream] = item;
}

ObjCache(void (*a_offload)(const Data&)): offload(a_offload) {}
explicit ObjCache(void (*a_offload)(const Data&)): offload(a_offload) {}
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we really want these constructors to be explicit? Are they even used implicitly anywhere in our codebase?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They aren’t used implicitly anywhere; the only constructions are direct ObjCache<T,K>(nullptr) from ObjPool and direct init of service_stream_cache with a lambda. explicit was added only to satisfy cpplint runtime/explicit. If we prefer not to mark these constructors as explicit, we can drop explicit and suppress that line with NOLINT for cpplint instead.

There are no APIs that take an ObjCache by value. So explicit is not required for correctness, only for style / tooling.


~ObjCache()
{
Expand Down Expand Up @@ -461,7 +461,7 @@ template<typename T>
class NameMapper
{
public:
NameMapper(const std::unordered_map<T, std::string_view>& name_map): map(name_map) {}
explicit NameMapper(const std::unordered_map<T, std::string_view>& name_map): map(name_map) {}
const std::string_view &getName(const T &val) {
return map.at(val);
}
Expand Down Expand Up @@ -769,14 +769,17 @@ protected:
}

#if HIPBLASLT_VERSION_MAJOR > 0 || HIPBLASLT_VERSION_MINOR >= 15
if (cfg.scaling_mode < 0 || cfg.scaling_mode >= (int)HIPBLASLT_MATMUL_MATRIX_SCALE_END)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Line length and { at the end of the line are understood, but lint does not require duplicating the body of the if.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, the duplicate was to fix the readability/braces warning. I have restructured it to compute a bool in preprocessor-only branches and then use a single if body.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like there is a discrepancy here. .clang-format sets the line width to 100, and this is what the IDE uses, so cpplint should be configured accordingly.

if (cfg.scaling_mode < 0 ||
cfg.scaling_mode >= static_cast<int>(HIPBLASLT_MATMUL_MATRIX_SCALE_END)) {
std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n";
continue;
}
#else
if (cfg.scaling_mode != 0)
#endif
{
if (cfg.scaling_mode != 0) {
std::cout << "[WARNING] Unsupported scaling mode at " << line << "\n";
continue;
}
#endif

auto fp8_filter = te_fp8_fnuz()
? [](const hipDataType& val)
Expand Down Expand Up @@ -966,10 +969,10 @@ void hipblaslt_gemm(const Tensor *inputA,
std::cout << "m=" << m << " k=" << k << " n=" << n
<< " transa=" << (param.transA == HIPBLAS_OP_T ? "T" : "N")
<< " transb=" << (param.transB == HIPBLAS_OP_T ? "T" : "N")
<< " A_type=" << (int)(param.Atype)
<< " B_type=" << (int)(param.Btype)
<< " D_type=" << (int)outputD->data.dtype
<< " bias_type=" << (int)inputBias->data.dtype
<< " A_type=" << static_cast<int>(param.Atype)
<< " B_type=" << static_cast<int>(param.Btype)
<< " D_type=" << static_cast<int>(outputD->data.dtype)
<< " bias_type=" << static_cast<int>(inputBias->data.dtype)
<< " grad=" << grad
<< " bias=" << (inputBias->data.dptr != nullptr)
<< " gelu=" << (outputPreGelu->data.dptr != nullptr)
Expand Down Expand Up @@ -1386,7 +1389,7 @@ void hipblaslt_gemm(const Tensor *inputA,
}


typedef unsigned long long ServiceStreamKey;
using ServiceStreamKey = std::uint64_t;

ServiceStreamKey make_service_stream_key(const int device_id, const int cu_count) {
return (static_cast<ServiceStreamKey>(device_id) << 32) | static_cast<ServiceStreamKey>(cu_count);
Expand Down
28 changes: 15 additions & 13 deletions transformer_engine/common/normalization/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,24 +544,26 @@ NormalizationPlanBase* NormalizationPlanRegistry::getNormalizationPlan(
plan = std::make_unique<CudnnNormalizationPlan>(NormType, NormStage, wtype, itype, otype, ctype,
batch_size, hidden_size, sm_count,
zero_centered_gamma, mode, training);
} else
}
#endif
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's get rid of nested ifs. If splitting else-if does not work, it is better to add a dummy 'if (false) {' for ROCm instead of 'if (NormBackend...'

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Dropped the nested if (!plan) / inner if (Forward) structure and added the following clearer structure:

CUDA: if (Cudnn) … else if (Forward TE) … else (backward TE) in one #ifndef __HIP_PLATFORM_AMD__ block.
ROCm: if (Forward TE) … else (backward TE) in #else, with mode/training on TE constructors.

if (NormStage == NVTE_Norm_Stage::Forward) {
plan = std::make_unique<TeNormalizationPlan<ForwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned
if (!plan) {
if (NormStage == NVTE_Norm_Stage::Forward) {
plan = std::make_unique<TeNormalizationPlan<ForwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned
#ifdef __HIP_PLATFORM_AMD__
, mode, training
, mode, training
#endif
);
} else {
plan = std::make_unique<TeNormalizationPlan<BackwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned
);
} else {
plan = std::make_unique<TeNormalizationPlan<BackwardKernelParams>>(
NormType, NormStage, wtype, itype, otype, ctype, batch_size, hidden_size, sm_count,
zero_centered_gamma, is_tuned
#ifdef __HIP_PLATFORM_AMD__
, mode, training
, mode, training
#endif
);
);
}
}
normalizationPlanMap.insert({key, std::move(plan)});
return normalizationPlanMap[key].get();
Expand Down
14 changes: 6 additions & 8 deletions transformer_engine/common/normalization/layernorm/ln_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
CheckOutputTensor(*rsigma, "rsigma");
}

NVTE_Norm_Backend norm_backend;
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
bool is_aligned = true;
#ifndef __HIP_PLATFORM_AMD__
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
Expand All @@ -85,10 +85,9 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else
}
#endif //__HIP_PLATFORM_AMD__
{
norm_backend = NVTE_Norm_Backend::Te;
if (norm_backend == NVTE_Norm_Backend::Te) {
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, beta.data.dptr,
mu->data.dptr, rsigma->data.dptr);
}
Expand Down Expand Up @@ -169,18 +168,17 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te
CheckOutputTensor(*dbeta, "dbeta");
}

NVTE_Norm_Backend norm_backend;
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
#ifndef __HIP_PLATFORM_AMD__
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else
}
#endif
{
norm_backend = NVTE_Norm_Backend::Te;
if (norm_backend == NVTE_Norm_Backend::Te) {
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr,
dx->data.dptr, dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
CheckOutputTensor(*rsigma, "rsigma");
}

NVTE_Norm_Backend norm_backend;
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
bool is_aligned = true;
#ifndef __HIP_PLATFORM_AMD__
bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp8_scaling(z->scaling_mode);
Expand All @@ -76,10 +76,9 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else
}
#endif
{
norm_backend = NVTE_Norm_Backend::Te;
if (norm_backend == NVTE_Norm_Backend::Te) {
is_aligned = is_ptr_aligned(z->data.dptr, x.data.dptr, gamma.data.dptr, rsigma->data.dptr);
}

Expand Down Expand Up @@ -148,18 +147,17 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const
CheckOutputTensor(*dgamma, "dgamma");
}

NVTE_Norm_Backend norm_backend;
NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te;
bool is_aligned = true;
bool gamma_in_weight_dtype = false;
#ifndef __HIP_PLATFORM_AMD__
if (use_cudnn_norm_bwd()) {
// TODO: add check for GPU ARCH
norm_backend = NVTE_Norm_Backend::Cudnn;
gamma_in_weight_dtype = use_zero_centered_gamma_in_weight_dtype();
} else
}
#endif
{
norm_backend = NVTE_Norm_Backend::Te;
if (norm_backend == NVTE_Norm_Backend::Te) {
is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr,
dz.data.dptr, dgamma->data.dptr);
}
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/common/permutation/permutation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,12 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
#endif

for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = float(TCompute(frag_load_store_ptr[e]));
frag_sum[e] = static_cast<float>(TCompute(frag_load_store_ptr[e]));
}

if (hasProb) {
for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] = frag_sum[e] * float(s_prob[0]);
frag_sum[e] = frag_sum[e] * static_cast<float>(s_prob[0]);
}
}
} else {
Expand Down Expand Up @@ -120,7 +120,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
}

for (int e = 0; e < kElementsPerAccess; e++) {
frag_sum[e] += float(frag_elem[e]);
frag_sum[e] += static_cast<float>(frag_elem[e]);
}
}

Expand All @@ -129,7 +129,7 @@ __global__ void moe_unpermute_kernel(const T *input, T *unpermuted_output, const
for (int e = 0; e < kElementsPerAccess; e++) {
if constexpr ((std::is_same_v<T, transformer_engine::fp8e4m3> || std::is_same_v<T, transformer_engine::fp8e5m2>) &&
(!hasProb)) {
frag_sum[e] = frag_sum[e] / float(TCompute(topK));
frag_sum[e] = frag_sum[e] / static_cast<float>(TCompute(topK));
}
frag_load_store_ptr[e] = T(TCompute(frag_sum[e]));
}
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/util/logging.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
const hipblasStatus_t status_NVTE_CHECK_CUBLAS = (expr); \
if (status_NVTE_CHECK_CUBLAS != CUBLAS_STATUS_SUCCESS) { \
NVTE_ERROR("HIPBLASLT Error: ", \
std::to_string((int)status_NVTE_CHECK_CUBLAS)); \
std::to_string(static_cast<int>(status_NVTE_CHECK_CUBLAS))); \
} \
} while (false)
#else //cublas
Expand Down
Loading