Skip to content

mca/op: always define aarch64 macros #13246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ompi/mca/op/aarch64/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@ specialized_op_libs =
if MCA_BUILD_ompi_op_has_neon_support
specialized_op_libs += liblocal_ops_neon.la
liblocal_ops_neon_la_SOURCES = op_aarch64_functions.c
liblocal_ops_neon_la_CPPFLAGS = -DGENERATE_NEON_CODE
liblocal_ops_neon_la_CPPFLAGS = -DGENERATE_NEON_CODE=1 -DGENERATE_SVE_CODE=0
endif
if MCA_BUILD_ompi_op_has_sve_support
specialized_op_libs += liblocal_ops_sve.la
liblocal_ops_sve_la_SOURCES = op_aarch64_functions.c
liblocal_ops_sve_la_CPPFLAGS = -DGENERATE_SVE_CODE
liblocal_ops_sve_la_CPPFLAGS = -DGENERATE_NEON_CODE=0 -DGENERATE_SVE_CODE=1
endif

component_noinst = $(specialized_op_libs)
Expand Down
34 changes: 4 additions & 30 deletions ompi/mca/op/aarch64/configure.m4
Original file line number Diff line number Diff line change
Expand Up @@ -48,29 +48,6 @@ AC_DEFUN([MCA_ompi_op_aarch64_CONFIG],[
[op_cv_neon_support=yes],
[op_cv_neon_support=no])])

#
# Check for NEON FP support
#
AC_CACHE_CHECK([for NEON FP support], op_cv_neon_fp_support,
[AS_IF([test "$op_cv_neon_support" = "yes"],
[
AC_LINK_IFELSE(
[AC_LANG_PROGRAM([[
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_NEON_FP) || defined(__ARM_FP))
#include <arm_neon.h>
#else
#error "No support for __aarch64__ or NEON FP"
#endif
]],
[[
#if defined(__aarch64__) && defined(__ARM_NEON) && (defined(__ARM_NEON_FP) || defined(__ARM_FP))
float32x4_t vA;
vA = vmovq_n_f32(0)
#endif
]])],
[op_cv_neon_fp_support=yes],
[op_cv_neon_fp_support=no])])])

#
# Check for SVE support
#
Expand Down Expand Up @@ -133,21 +110,18 @@ int main(void) {
])
AM_CONDITIONAL([MCA_BUILD_ompi_op_has_neon_support],
[test "$op_cv_neon_support" = "yes"])
AM_CONDITIONAL([MCA_BUILD_ompi_op_has_neon_fp_support],
[test "$op_cv_neon_fp_support" = "yes"])
AM_CONDITIONAL([MCA_BUILD_ompi_op_has_sve_support],
[test "$op_cv_sve_support" = "yes"])

AC_SUBST(MCA_BUILD_ompi_op_has_neon_support)
AC_SUBST(MCA_BUILD_ompi_op_has_neon_fp_support)
AC_SUBST(MCA_BUILD_ompi_op_has_sve_support)

AS_IF([test "$op_cv_neon_support" = "yes"],
[AC_DEFINE([OMPI_MCA_OP_HAVE_NEON], [1],[NEON supported in the current build])])
AS_IF([test "$op_cv_neon_fp_support" = "yes"],
[AC_DEFINE([OMPI_MCA_OP_HAVE_NEON_FP], [1],[NEON FP supported in the current build])])
[AC_DEFINE([OMPI_MCA_OP_HAVE_NEON], [1],[NEON supported in the current build])],
[AC_DEFINE([OMPI_MCA_OP_HAVE_NEON], [0],[NEON not supported in the current build])])
AS_IF([test "$op_cv_sve_support" = "yes"],
[AC_DEFINE([OMPI_MCA_OP_HAVE_SVE], [1],[SVE supported in the current build])])
[AC_DEFINE([OMPI_MCA_OP_HAVE_SVE], [1],[SVE supported in the current build])],
[AC_DEFINE([OMPI_MCA_OP_HAVE_SVE], [0],[SVE not supported in the current build])])
AS_IF([test "$op_cv_sve_add_flags" = "yes"],
[AC_DEFINE([OMPI_MCA_OP_SVE_EXTRA_FLAGS], [1],[SVE supported with additional compile attributes])],
[AC_DEFINE([OMPI_MCA_OP_SVE_EXTRA_FLAGS], [0],[SVE not supported])])
Expand Down
24 changes: 12 additions & 12 deletions ompi/mca/op/aarch64/op_aarch64_component.c
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,12 @@ OMPI_SVE_ATTR static int mca_op_aarch64_component_register(void)
{

mca_op_aarch64_component.hardware_available = 1; /* Check for Neon */
#if defined(OMPI_MCA_OP_HAVE_SVE)
#if OMPI_MCA_OP_HAVE_SVE
uint64_t id_aa64pfr0_el1 = (1UL << 32);
__asm__("mrs %0, ID_AA64PFR0_EL1" : "=r"(id_aa64pfr0_el1) : :);
/* Check for SVE support */
mca_op_aarch64_component.hardware_available |= ((id_aa64pfr0_el1 & (1UL << 32)) ? 2 : 0);
#endif /* defined(OMPI_MCA_OP_HAVE_SVE) */
#endif /* OMPI_MCA_OP_HAVE_SVE */
(void) mca_base_component_var_register(&mca_op_aarch64_component.super.opc_version,
"hardware_available",
"Whether the Neon (1) or SVE (2) hardware is available",
Expand All @@ -119,9 +119,9 @@ OMPI_SVE_ATTR static int mca_op_aarch64_component_register(void)
MCA_BASE_VAR_SCOPE_READONLY,
&mca_op_aarch64_component.hardware_available);
uint64_t id_aa64zfr0_el1 = 0;
#if defined(OMPI_MCA_OP_HAVE_SVE)
#if OMPI_MCA_OP_HAVE_SVE
__asm__("mrs %0, ID_AA64ZFR0_EL1" : "=r"(id_aa64zfr0_el1) : :);
#endif /* defined(OMPI_MCA_OP_HAVE_SVE) */
#endif /* OMPI_MCA_OP_HAVE_SVE */
mca_op_aarch64_component.double_supported = id_aa64zfr0_el1 & (1UL << 56);
/* Bit 1: mandatory SVE2 instructions */
/* Bit 2: mandatory SVE2.1 instructions */
Expand All @@ -148,18 +148,18 @@ static int mca_op_aarch64_component_init_query(bool enable_progress_threads,
return OMPI_ERR_NOT_SUPPORTED;
}

#if defined(OMPI_MCA_OP_HAVE_NEON)
#if OMPI_MCA_OP_HAVE_NEON
extern ompi_op_base_handler_fn_t
ompi_op_aarch64_functions_neon[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
extern ompi_op_base_3buff_handler_fn_t
ompi_op_aarch64_3buff_functions_neon[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
#endif /* defined(OMPI_MCA_OP_HAVE_NEON) */
#if defined(OMPI_MCA_OP_HAVE_SVE)
#endif /* OMPI_MCA_OP_HAVE_NEON */
#if OMPI_MCA_OP_HAVE_SVE
extern ompi_op_base_handler_fn_t
ompi_op_aarch64_functions_sve[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
extern ompi_op_base_3buff_handler_fn_t
ompi_op_aarch64_3buff_functions_sve[OMPI_OP_BASE_FORTRAN_OP_MAX][OMPI_OP_BASE_TYPE_MAX];
#endif /* defined(OMPI_MCA_OP_HAVE_SVE) */
#endif /* OMPI_MCA_OP_HAVE_SVE */

/*
* Query whether this component can be used for a specific op
Expand Down Expand Up @@ -189,13 +189,13 @@ static struct ompi_op_base_module_1_0_0_t *
for (int i = 0; i < OMPI_OP_BASE_TYPE_MAX; ++i) {
module->opm_fns[i] = NULL;
module->opm_3buff_fns[i] = NULL;
#if defined(OMPI_MCA_OP_HAVE_SVE)
#if OMPI_MCA_OP_HAVE_SVE
if( mca_op_aarch64_component.hardware_available & 2 ) {
module->opm_fns[i] = ompi_op_aarch64_functions_sve[op->o_f_to_c_index][i];
module->opm_3buff_fns[i] = ompi_op_aarch64_3buff_functions_sve[op->o_f_to_c_index][i];
}
#endif /* defined(OMPI_MCA_OP_HAVE_SVE) */
#if defined(OMPI_MCA_OP_HAVE_NEON)
#endif /* OMPI_MCA_OP_HAVE_SVE */
#if OMPI_MCA_OP_HAVE_NEON
if( mca_op_aarch64_component.hardware_available & 1 ) {
if( NULL == module->opm_fns[i] ) {
module->opm_fns[i] = ompi_op_aarch64_functions_neon[op->o_f_to_c_index][i];
Expand All @@ -204,7 +204,7 @@ static struct ompi_op_base_module_1_0_0_t *
module->opm_3buff_fns[i] = ompi_op_aarch64_3buff_functions_neon[op->o_f_to_c_index][i];
}
}
#endif /* defined(OMPI_MCA_OP_HAVE_NEON) */
#endif /* OMPI_MCA_OP_HAVE_NEON */
}
break;
case OMPI_OP_BASE_FORTRAN_LAND:
Expand Down
75 changes: 41 additions & 34 deletions ompi/mca/op/aarch64/op_aarch64_functions.c
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,25 @@
#include "ompi/mca/op/base/base.h"
#include "ompi/mca/op/aarch64/op_aarch64.h"

#if defined(GENERATE_SVE_CODE)
/**
* Ensure exactly one of GENERATE_SVE_CODE or GENERATE_NEON_CODE is enabled.
* Enabling both is invalid as each builds a separate library. Disabling both
* would leave no implementation to compile.
*/
#if GENERATE_SVE_CODE && GENERATE_NEON_CODE
#error "Never build NEON and SVE within the same library"
#elif GENERATE_SVE_CODE
# include <arm_sve.h>
#define OMPI_OP_TYPE_PREPEND sv
#define OMPI_OP_OP_PREPEND sv
#define APPEND _sve
#elif defined(GENERATE_NEON_CODE)
#elif GENERATE_NEON_CODE
# include <arm_neon.h>
#define OMPI_OP_TYPE_PREPEND
#define OMPI_OP_OP_PREPEND v
#define APPEND _neon
#else
#error we should not reach this
#error "Neither NEON nor SVE code generated. This should never happen"
#endif /* GENERATE_SVE_CODE / GENERATE_NEON_CODE */

/*
Expand All @@ -51,7 +58,7 @@
*/
#define OP_CONCAT(A, B) OP_CONCAT_NX(A, B)

#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
# define svcnt(X) \
_Generic((X), \
int8_t: svcntb, \
Expand Down Expand Up @@ -101,7 +108,7 @@ _Generic((*(out)), \
uint64_t: __extension__({ switch ((how_much)) { DUMP2(out, in1, in2) }}), \
float32_t: __extension__({ switch ((how_much)) { DUMP4(out, in1, in2) }}), \
float64_t: __extension__({ switch ((how_much)) { DUMP2(out, in1, in2) }}))
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

/*
* Since all the functions in this file are essentially identical, we
Expand All @@ -111,7 +118,7 @@ _Generic((*(out)), \
* This macro is for (out op in).
*
*/
#if defined(GENERATE_NEON_CODE)
#if GENERATE_NEON_CODE
#define OP_AARCH64_FUNC(name, type_name, type_size, type_cnt, type, op) \
static void OP_CONCAT(ompi_op_aarch64_2buff_##name##_##type##type_size##_t, \
APPEND)(const void *_in, void *_out, int *count, \
Expand All @@ -135,7 +142,7 @@ _Generic((*(out)), \
neon_loop(left_over, out, out, in); \
} \
}
#elif defined(GENERATE_SVE_CODE)
#elif GENERATE_SVE_CODE
#define OP_AARCH64_FUNC(name, type_name, type_size, type_cnt, type, op) \
OMPI_SVE_ATTR \
static void OP_CONCAT(ompi_op_aarch64_2buff_##name##_##type##type_size##_t, APPEND) \
Expand Down Expand Up @@ -169,10 +176,10 @@ _Generic((*(out)), \
OP_AARCH64_FUNC(max, u, 16, 8, uint, max)
OP_AARCH64_FUNC(max, s, 32, 4, int, max)
OP_AARCH64_FUNC(max, u, 32, 4, uint, max)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC(max, s, 64, 2, int, max)
OP_AARCH64_FUNC(max, u, 64, 2, uint, max)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC(max, f, 32, 4, float, max)
OP_AARCH64_FUNC(max, f, 64, 2, float, max)
Expand All @@ -188,10 +195,10 @@ _Generic((*(out)), \
OP_AARCH64_FUNC(min, u, 16, 8, uint, min)
OP_AARCH64_FUNC(min, s, 32, 4, int, min)
OP_AARCH64_FUNC(min, u, 32, 4, uint, min)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC(min, s, 64, 2, int, min)
OP_AARCH64_FUNC(min, u, 64, 2, uint, min)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC(min, f, 32, 4, float, min)
OP_AARCH64_FUNC(min, f, 64, 2, float, min)
Expand Down Expand Up @@ -223,10 +230,10 @@ _Generic((*(out)), \
OP_AARCH64_FUNC(prod, u, 16, 8, uint, mul)
OP_AARCH64_FUNC(prod, s, 32, 4, int, mul)
OP_AARCH64_FUNC(prod, u, 32, 4, uint, mul)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC(prod, s, 64, 2, int, mul)
OP_AARCH64_FUNC(prod, u, 64, 2, uint, mul)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC(prod, f, 32, 4, float, mul)
OP_AARCH64_FUNC(prod, f, 64, 2, float, mul)
Expand Down Expand Up @@ -277,7 +284,7 @@ _Generic((*(out)), \
* This is a three buffer (2 input and 1 output) version of the reduction
* routines, needed for some optimizations.
*/
#if defined(GENERATE_NEON_CODE)
#if GENERATE_NEON_CODE
#define OP_AARCH64_FUNC_3BUFF(name, type_name, type_size, type_cnt, type, op) \
static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPEND) \
(const void *_in1, const void *_in2, void *_out, int *count, \
Expand All @@ -302,7 +309,7 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
neon_loop(left_over, out, in1, in2); \
} \
}
#elif defined(GENERATE_SVE_CODE)
#elif GENERATE_SVE_CODE
#define OP_AARCH64_FUNC_3BUFF(name, type_name, type_size, type_cnt, type, op) \
OMPI_SVE_ATTR \
static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPEND) \
Expand All @@ -324,7 +331,7 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
OP_CONCAT(OMPI_OP_OP_PREPEND, st1)(pred, &out[idx], vdst); \
} \
}
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

/*************************************************************************
* Max
Expand All @@ -337,10 +344,10 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
OP_AARCH64_FUNC_3BUFF(max, u, 16, 8, uint, max)
OP_AARCH64_FUNC_3BUFF(max, s, 32, 4, int, max)
OP_AARCH64_FUNC_3BUFF(max, u, 32, 4, uint, max)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC_3BUFF(max, s, 64, 2, int, max)
OP_AARCH64_FUNC_3BUFF(max, u, 64, 2, uint, max)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC_3BUFF(max, f, 32, 4, float, max)
OP_AARCH64_FUNC_3BUFF(max, f, 64, 2, float, max)
Expand All @@ -356,10 +363,10 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
OP_AARCH64_FUNC_3BUFF(min, u, 16, 8, uint, min)
OP_AARCH64_FUNC_3BUFF(min, s, 32, 4, int, min)
OP_AARCH64_FUNC_3BUFF(min, u, 32, 4, uint, min)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC_3BUFF(min, s, 64, 2, int, min)
OP_AARCH64_FUNC_3BUFF(min, u, 64, 2, uint, min)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC_3BUFF(min, f, 32, 4, float, min)
OP_AARCH64_FUNC_3BUFF(min, f, 64, 2, float, min)
Expand Down Expand Up @@ -392,10 +399,10 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
OP_AARCH64_FUNC_3BUFF(prod, u, 16, 8, uint, mul)
OP_AARCH64_FUNC_3BUFF(prod, s, 32, 4, int, mul)
OP_AARCH64_FUNC_3BUFF(prod, u, 32, 4, uint, mul)
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
OP_AARCH64_FUNC_3BUFF(prod, s, 64, 2, int, mul)
OP_AARCH64_FUNC_3BUFF(prod, u, 64, 2, uint, mul)
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */

OP_AARCH64_FUNC_3BUFF(prod, f, 32, 4, float, mul)
OP_AARCH64_FUNC_3BUFF(prod, f, 64, 2, float, mul)
Expand Down Expand Up @@ -482,17 +489,17 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
/* Corresponds to MPI_MAX */
[OMPI_OP_BASE_FORTRAN_MAX] = {
C_INTEGER_BASE(max, 2buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(max, 2buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(max, 2buff),
},
/* Corresponds to MPI_MIN */
[OMPI_OP_BASE_FORTRAN_MIN] = {
C_INTEGER_BASE(min, 2buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(min, 2buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(min, 2buff),
},
/* Corresponds to MPI_SUM */
Expand All @@ -504,9 +511,9 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
/* Corresponds to MPI_PROD */
[OMPI_OP_BASE_FORTRAN_PROD] = {
C_INTEGER_BASE(prod, 2buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(prod, 2buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(prod, 2buff),
},
/* Corresponds to MPI_LAND */
Expand Down Expand Up @@ -558,17 +565,17 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
/* Corresponds to MPI_MAX */
[OMPI_OP_BASE_FORTRAN_MAX] = {
C_INTEGER_BASE(max, 3buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(max, 3buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(max, 3buff),
},
/* Corresponds to MPI_MIN */
[OMPI_OP_BASE_FORTRAN_MIN] = {
C_INTEGER_BASE(min, 3buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(min, 3buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(min, 3buff),
},
/* Corresponds to MPI_SUM */
Expand All @@ -580,9 +587,9 @@ static void OP_CONCAT(ompi_op_aarch64_3buff_##name##_##type##type_size##_t, APPE
/* Corresponds to MPI_PROD */
[OMPI_OP_BASE_FORTRAN_PROD] = {
C_INTEGER_BASE(prod, 3buff),
#if defined(GENERATE_SVE_CODE)
#if GENERATE_SVE_CODE
C_INTEGER_EX(prod, 3buff),
#endif /* defined(GENERATE_SVE_CODE) */
#endif /* GENERATE_SVE_CODE */
FLOATING_POINT(prod, 3buff),
},
/* Corresponds to MPI_LAND */
Expand Down