Skip to content

Commit 250f2bb

Browse files
kushanam and Artem-B committed
adding bf16 support to NVPTX
Currently, bf16 support has been added to the PTX codegen only piecemeal. This patch aims to complete the set of instructions and code paths required to support the bf16 data type.
Reviewed By: tra
Differential Revision: https://reviews.llvm.org/D144911
Co-authored-by: Artem Belevich <[email protected]>
1 parent 85bdea0 commit 250f2bb

24 files changed

+1706 additions, -370 deletions (lines changed)

clang/include/clang/Basic/BuiltinsNVPTX.def

Lines changed: 40 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -173,16 +173,20 @@ TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
173173
AND(SM_86, PTX72))
174174
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
175175
AND(SM_86, PTX72))
176-
TARGET_BUILTIN(__nvvm_fmin_bf16, "UsUsUs", "", AND(SM_80, PTX70))
177-
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
178-
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
179-
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "UsUsUs", "",
176+
TARGET_BUILTIN(__nvvm_fmin_bf16, "yyy", "", AND(SM_80, PTX70))
177+
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
178+
TARGET_BUILTIN(__nvvm_fmin_nan_bf16, "yyy", "", AND(SM_80, PTX70))
179+
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
180+
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
181+
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16, "yyy", "",
180182
AND(SM_86, PTX72))
181-
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
182-
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
183-
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
183+
TARGET_BUILTIN(__nvvm_fmin_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
184+
TARGET_BUILTIN(__nvvm_fmin_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
185+
TARGET_BUILTIN(__nvvm_fmin_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
186+
TARGET_BUILTIN(__nvvm_fmin_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
187+
TARGET_BUILTIN(__nvvm_fmin_xorsign_abs_bf16x2, "V2yV2yV2y", "",
184188
AND(SM_86, PTX72))
185-
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
189+
TARGET_BUILTIN(__nvvm_fmin_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
186190
AND(SM_86, PTX72))
187191
BUILTIN(__nvvm_fmin_f, "fff", "")
188192
BUILTIN(__nvvm_fmin_ftz_f, "fff", "")
@@ -215,16 +219,20 @@ TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
215219
AND(SM_86, PTX72))
216220
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_xorsign_abs_f16x2, "V2hV2hV2h", "",
217221
AND(SM_86, PTX72))
218-
TARGET_BUILTIN(__nvvm_fmax_bf16, "UsUsUs", "", AND(SM_80, PTX70))
219-
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "UsUsUs", "", AND(SM_80, PTX70))
220-
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "UsUsUs", "", AND(SM_86, PTX72))
221-
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "UsUsUs", "",
222+
TARGET_BUILTIN(__nvvm_fmax_bf16, "yyy", "", AND(SM_80, PTX70))
223+
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16, "yyy", "", AND(SM_80, PTX70))
224+
TARGET_BUILTIN(__nvvm_fmax_nan_bf16, "yyy", "", AND(SM_80, PTX70))
225+
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16, "yyy", "", AND(SM_80, PTX70))
226+
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16, "yyy", "", AND(SM_86, PTX72))
227+
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16, "yyy", "",
222228
AND(SM_86, PTX72))
223-
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
224-
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "ZUiZUiZUi", "", AND(SM_80, PTX70))
225-
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
229+
TARGET_BUILTIN(__nvvm_fmax_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
230+
TARGET_BUILTIN(__nvvm_fmax_ftz_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
231+
TARGET_BUILTIN(__nvvm_fmax_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
232+
TARGET_BUILTIN(__nvvm_fmax_ftz_nan_bf16x2, "V2yV2yV2y", "", AND(SM_80, PTX70))
233+
TARGET_BUILTIN(__nvvm_fmax_xorsign_abs_bf16x2, "V2yV2yV2y", "",
226234
AND(SM_86, PTX72))
227-
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "ZUiZUiZUi", "",
235+
TARGET_BUILTIN(__nvvm_fmax_nan_xorsign_abs_bf16x2, "V2yV2yV2y", "",
228236
AND(SM_86, PTX72))
229237
BUILTIN(__nvvm_fmax_f, "fff", "")
230238
BUILTIN(__nvvm_fmax_ftz_f, "fff", "")
@@ -352,10 +360,10 @@ TARGET_BUILTIN(__nvvm_fma_rn_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
352360
TARGET_BUILTIN(__nvvm_fma_rn_ftz_sat_f16x2, "V2hV2hV2hV2h", "", AND(SM_53, PTX42))
353361
TARGET_BUILTIN(__nvvm_fma_rn_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
354362
TARGET_BUILTIN(__nvvm_fma_rn_ftz_relu_f16x2, "V2hV2hV2hV2h", "", AND(SM_80, PTX70))
355-
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
356-
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "UsUsUsUs", "", AND(SM_80, PTX70))
357-
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
358-
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "ZUiZUiZUiZUi", "", AND(SM_80, PTX70))
363+
TARGET_BUILTIN(__nvvm_fma_rn_bf16, "yyyy", "", AND(SM_80, PTX70))
364+
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16, "yyyy", "", AND(SM_80, PTX70))
365+
TARGET_BUILTIN(__nvvm_fma_rn_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
366+
TARGET_BUILTIN(__nvvm_fma_rn_relu_bf16x2, "V2yV2yV2yV2y", "", AND(SM_80, PTX70))
359367
BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
360368
BUILTIN(__nvvm_fma_rn_f, "ffff", "")
361369
BUILTIN(__nvvm_fma_rz_ftz_f, "ffff", "")
@@ -543,20 +551,20 @@ BUILTIN(__nvvm_ull2d_rp, "dULLi", "")
543551
BUILTIN(__nvvm_f2h_rn_ftz, "Usf", "")
544552
BUILTIN(__nvvm_f2h_rn, "Usf", "")
545553

546-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "ZUiff", "", AND(SM_80,PTX70))
547-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "ZUiff", "", AND(SM_80,PTX70))
548-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "ZUiff", "", AND(SM_80,PTX70))
549-
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "ZUiff", "", AND(SM_80,PTX70))
554+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn, "V2yff", "", AND(SM_80,PTX70))
555+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rn_relu, "V2yff", "", AND(SM_80,PTX70))
556+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz, "V2yff", "", AND(SM_80,PTX70))
557+
TARGET_BUILTIN(__nvvm_ff2bf16x2_rz_relu, "V2yff", "", AND(SM_80,PTX70))
550558

551559
TARGET_BUILTIN(__nvvm_ff2f16x2_rn, "V2hff", "", AND(SM_80,PTX70))
552560
TARGET_BUILTIN(__nvvm_ff2f16x2_rn_relu, "V2hff", "", AND(SM_80,PTX70))
553561
TARGET_BUILTIN(__nvvm_ff2f16x2_rz, "V2hff", "", AND(SM_80,PTX70))
554562
TARGET_BUILTIN(__nvvm_ff2f16x2_rz_relu, "V2hff", "", AND(SM_80,PTX70))
555563

556-
TARGET_BUILTIN(__nvvm_f2bf16_rn, "ZUsf", "", AND(SM_80,PTX70))
557-
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "ZUsf", "", AND(SM_80,PTX70))
558-
TARGET_BUILTIN(__nvvm_f2bf16_rz, "ZUsf", "", AND(SM_80,PTX70))
559-
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "ZUsf", "", AND(SM_80,PTX70))
564+
TARGET_BUILTIN(__nvvm_f2bf16_rn, "yf", "", AND(SM_80,PTX70))
565+
TARGET_BUILTIN(__nvvm_f2bf16_rn_relu, "yf", "", AND(SM_80,PTX70))
566+
TARGET_BUILTIN(__nvvm_f2bf16_rz, "yf", "", AND(SM_80,PTX70))
567+
TARGET_BUILTIN(__nvvm_f2bf16_rz_relu, "yf", "", AND(SM_80,PTX70))
560568

561569
TARGET_BUILTIN(__nvvm_f2tf32_rna, "ZUif", "", AND(SM_80,PTX70))
562570

@@ -1024,10 +1032,10 @@ TARGET_BUILTIN(__nvvm_cp_async_wait_all, "v", "", AND(SM_80,PTX70))
10241032

10251033

10261034
// bf16, bf16x2 abs, neg
1027-
TARGET_BUILTIN(__nvvm_abs_bf16, "UsUs", "", AND(SM_80,PTX70))
1028-
TARGET_BUILTIN(__nvvm_abs_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
1029-
TARGET_BUILTIN(__nvvm_neg_bf16, "UsUs", "", AND(SM_80,PTX70))
1030-
TARGET_BUILTIN(__nvvm_neg_bf16x2, "ZUiZUi", "", AND(SM_80,PTX70))
1035+
TARGET_BUILTIN(__nvvm_abs_bf16, "yy", "", AND(SM_80,PTX70))
1036+
TARGET_BUILTIN(__nvvm_abs_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
1037+
TARGET_BUILTIN(__nvvm_neg_bf16, "yy", "", AND(SM_80,PTX70))
1038+
TARGET_BUILTIN(__nvvm_neg_bf16x2, "V2yV2y", "", AND(SM_80,PTX70))
10311039

10321040
TARGET_BUILTIN(__nvvm_mapa, "v*v*i", "", AND(SM_90, PTX78))
10331041
TARGET_BUILTIN(__nvvm_mapa_shared_cluster, "v*3v*3i", "", AND(SM_90, PTX78))

clang/test/CodeGen/builtins-nvptx.c

Lines changed: 65 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -899,13 +899,13 @@ __device__ void nvvm_async_copy(__attribute__((address_space(3))) void* dst, __a
899899
// CHECK-LABEL: nvvm_cvt_sm80
900900
__device__ void nvvm_cvt_sm80() {
901901
#if __CUDA_ARCH__ >= 800
902-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
902+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn(float 1.000000e+00, float 1.000000e+00)
903903
__nvvm_ff2bf16x2_rn(1, 1);
904-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
904+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rn.relu(float 1.000000e+00, float 1.000000e+00)
905905
__nvvm_ff2bf16x2_rn_relu(1, 1);
906-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
906+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz(float 1.000000e+00, float 1.000000e+00)
907907
__nvvm_ff2bf16x2_rz(1, 1);
908-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
908+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.ff2bf16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
909909
__nvvm_ff2bf16x2_rz_relu(1, 1);
910910

911911
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rn(float 1.000000e+00, float 1.000000e+00)
@@ -917,13 +917,13 @@ __device__ void nvvm_cvt_sm80() {
917917
// CHECK_PTX70_SM80: call <2 x half> @llvm.nvvm.ff2f16x2.rz.relu(float 1.000000e+00, float 1.000000e+00)
918918
__nvvm_ff2f16x2_rz_relu(1, 1);
919919

920-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
920+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn(float 1.000000e+00)
921921
__nvvm_f2bf16_rn(1);
922-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
922+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rn.relu(float 1.000000e+00)
923923
__nvvm_f2bf16_rn_relu(1);
924-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
924+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz(float 1.000000e+00)
925925
__nvvm_f2bf16_rz(1);
926-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
926+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.f2bf16.rz.relu(float 1.000000e+00)
927927
__nvvm_f2bf16_rz_relu(1);
928928

929929
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.f2tf32.rna(float 1.000000e+00)
@@ -932,32 +932,32 @@ __device__ void nvvm_cvt_sm80() {
932932
// CHECK: ret void
933933
}
934934

935+
#define NAN32 0x7FBFFFFF
936+
#define NAN16 (__bf16)0x7FBF
937+
#define BF16 (__bf16)0.1f
938+
#define BF16_2 (__bf16)0.2f
939+
#define NANBF16 (__bf16)0xFFC1
940+
#define BF16X2 {(__bf16)0.1f, (__bf16)0.1f}
941+
#define BF16X2_2 {(__bf16)0.2f, (__bf16)0.2f}
942+
#define NANBF16X2 {NANBF16, NANBF16}
943+
935944
// CHECK-LABEL: nvvm_abs_neg_bf16_bf16x2_sm80
936945
__device__ void nvvm_abs_neg_bf16_bf16x2_sm80() {
937946
#if __CUDA_ARCH__ >= 800
938947

939-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.abs.bf16(i16 -1)
940-
__nvvm_abs_bf16(0xFFFF);
941-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.abs.bf16x2(i32 -1)
942-
__nvvm_abs_bf16x2(0xFFFFFFFF);
948+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.abs.bf16(bfloat 0xR3DCD)
949+
__nvvm_abs_bf16(BF16);
950+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.abs.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
951+
__nvvm_abs_bf16x2(BF16X2);
943952

944-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.neg.bf16(i16 -1)
945-
__nvvm_neg_bf16(0xFFFF);
946-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.neg.bf16x2(i32 -1)
947-
__nvvm_neg_bf16x2(0xFFFFFFFF);
953+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.neg.bf16(bfloat 0xR3DCD)
954+
__nvvm_neg_bf16(BF16);
955+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.neg.bf16x2(<2 x bfloat> <bfloat 0xR3DCD, bfloat 0xR3DCD>)
956+
__nvvm_neg_bf16x2(BF16X2);
948957
#endif
949958
// CHECK: ret void
950959
}
951960

952-
#define NAN32 0x7FBFFFFF
953-
#define NAN16 0x7FBF
954-
#define BF16 0x1234
955-
#define BF16_2 0x4321
956-
#define NANBF16 0xFFC1
957-
#define BF16X2 0x12341234
958-
#define BF16X2_2 0x32343234
959-
#define NANBF16X2 0xFFC1FFC1
960-
961961
// CHECK-LABEL: nvvm_min_max_sm80
962962
__device__ void nvvm_min_max_sm80() {
963963
#if __CUDA_ARCH__ >= 800
@@ -967,14 +967,22 @@ __device__ void nvvm_min_max_sm80() {
967967
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmin.ftz.nan.f
968968
__nvvm_fmin_ftz_nan_f(0.1f, (float)NAN32);
969969

970-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.bf16
970+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.bf16
971971
__nvvm_fmin_bf16(BF16, BF16_2);
972-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmin.nan.bf16
972+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.bf16
973+
__nvvm_fmin_ftz_bf16(BF16, BF16_2);
974+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.nan.bf16
973975
__nvvm_fmin_nan_bf16(BF16, NANBF16);
974-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.bf16x2
976+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmin.ftz.nan.bf16
977+
__nvvm_fmin_ftz_nan_bf16(BF16, NANBF16);
978+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.bf16x2
975979
__nvvm_fmin_bf16x2(BF16X2, BF16X2_2);
976-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmin.nan.bf16x2
980+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.bf16x2
981+
__nvvm_fmin_ftz_bf16x2(BF16X2, BF16X2_2);
982+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.nan.bf16x2
977983
__nvvm_fmin_nan_bf16x2(BF16X2, NANBF16X2);
984+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmin.ftz.nan.bf16x2
985+
__nvvm_fmin_ftz_nan_bf16x2(BF16X2, NANBF16X2);
978986
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
979987
__nvvm_fmax_nan_f(0.1f, 0.11f);
980988
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -984,14 +992,22 @@ __device__ void nvvm_min_max_sm80() {
984992
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
985993
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
986994
__nvvm_fmax_ftz_nan_f(0.1f, (float)NAN32);
987-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.bf16
995+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.bf16
988996
__nvvm_fmax_bf16(BF16, BF16_2);
989-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fmax.nan.bf16
997+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.bf16
998+
__nvvm_fmax_ftz_bf16(BF16, BF16_2);
999+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.nan.bf16
9901000
__nvvm_fmax_nan_bf16(BF16, NANBF16);
991-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.bf16x2
1001+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fmax.ftz.nan.bf16
1002+
__nvvm_fmax_ftz_nan_bf16(BF16, NANBF16);
1003+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.bf16x2
9921004
__nvvm_fmax_bf16x2(BF16X2, BF16X2_2);
993-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fmax.nan.bf16x2
1005+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.bf16x2
1006+
__nvvm_fmax_ftz_bf16x2(BF16X2, BF16X2_2);
1007+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.nan.bf16x2
9941008
__nvvm_fmax_nan_bf16x2(NANBF16X2, BF16X2);
1009+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fmax.ftz.nan.bf16x2
1010+
__nvvm_fmax_ftz_nan_bf16x2(NANBF16X2, BF16X2);
9951011
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.nan.f
9961012
__nvvm_fmax_nan_f(0.1f, (float)NAN32);
9971013
// CHECK_PTX70_SM80: call float @llvm.nvvm.fmax.ftz.nan.f
@@ -1004,14 +1020,14 @@ __device__ void nvvm_min_max_sm80() {
10041020
// CHECK-LABEL: nvvm_fma_bf16_bf16x2_sm80
10051021
__device__ void nvvm_fma_bf16_bf16x2_sm80() {
10061022
#if __CUDA_ARCH__ >= 800
1007-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.bf16
1008-
__nvvm_fma_rn_bf16(0x1234, 0x7FBF, 0x1234);
1009-
// CHECK_PTX70_SM80: call i16 @llvm.nvvm.fma.rn.relu.bf16
1010-
__nvvm_fma_rn_relu_bf16(0x1234, 0x7FBF, 0x1234);
1011-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.bf16x2
1012-
__nvvm_fma_rn_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
1013-
// CHECK_PTX70_SM80: call i32 @llvm.nvvm.fma.rn.relu.bf16x2
1014-
__nvvm_fma_rn_relu_bf16x2(0x7FBFFFFF, 0xFFFFFFFF, 0x7FBFFFFF);
1023+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.bf16
1024+
__nvvm_fma_rn_bf16(BF16, BF16_2, BF16_2);
1025+
// CHECK_PTX70_SM80: call bfloat @llvm.nvvm.fma.rn.relu.bf16
1026+
__nvvm_fma_rn_relu_bf16(BF16, BF16_2, BF16_2);
1027+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.bf16x2
1028+
__nvvm_fma_rn_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
1029+
// CHECK_PTX70_SM80: call <2 x bfloat> @llvm.nvvm.fma.rn.relu.bf16x2
1030+
__nvvm_fma_rn_relu_bf16x2(BF16X2, BF16X2_2, BF16X2_2);
10151031
#endif
10161032
// CHECK: ret void
10171033
}
@@ -1020,13 +1036,13 @@ __device__ void nvvm_fma_bf16_bf16x2_sm80() {
10201036
__device__ void nvvm_min_max_sm86() {
10211037
#if __CUDA_ARCH__ >= 860
10221038

1023-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.xorsign.abs.bf16
1039+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.xorsign.abs.bf16
10241040
__nvvm_fmin_xorsign_abs_bf16(BF16, BF16_2);
1025-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmin.nan.xorsign.abs.bf16
1041+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmin.nan.xorsign.abs.bf16
10261042
__nvvm_fmin_nan_xorsign_abs_bf16(BF16, NANBF16);
1027-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.xorsign.abs.bf16x2
1043+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.xorsign.abs.bf16x2
10281044
__nvvm_fmin_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
1029-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
1045+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmin.nan.xorsign.abs.bf16x2
10301046
__nvvm_fmin_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
10311047
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.xorsign.abs.f
10321048
__nvvm_fmin_xorsign_abs_f(-0.1f, 0.1f);
@@ -1037,13 +1053,13 @@ __device__ void nvvm_min_max_sm86() {
10371053
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmin.ftz.nan.xorsign.abs.f
10381054
__nvvm_fmin_ftz_nan_xorsign_abs_f(-0.1f, (float)NAN32);
10391055

1040-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.xorsign.abs.bf16
1056+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.xorsign.abs.bf16
10411057
__nvvm_fmax_xorsign_abs_bf16(BF16, BF16_2);
1042-
// CHECK_PTX72_SM86: call i16 @llvm.nvvm.fmax.nan.xorsign.abs.bf16
1058+
// CHECK_PTX72_SM86: call bfloat @llvm.nvvm.fmax.nan.xorsign.abs.bf16
10431059
__nvvm_fmax_nan_xorsign_abs_bf16(BF16, NANBF16);
1044-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.xorsign.abs.bf16x2
1060+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.xorsign.abs.bf16x2
10451061
__nvvm_fmax_xorsign_abs_bf16x2(BF16X2, BF16X2_2);
1046-
// CHECK_PTX72_SM86: call i32 @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
1062+
// CHECK_PTX72_SM86: call <2 x bfloat> @llvm.nvvm.fmax.nan.xorsign.abs.bf16x2
10471063
__nvvm_fmax_nan_xorsign_abs_bf16x2(BF16X2, NANBF16X2);
10481064
// CHECK_PTX72_SM86: call float @llvm.nvvm.fmax.xorsign.abs.f
10491065
__nvvm_fmax_xorsign_abs_f(-0.1f, 0.1f);

clang/test/CodeGenCUDA/bf16.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
// CHECK-LABEL: .visible .func _Z8test_argPDF16bDF16b(
1010
// CHECK: .param .b64 _Z8test_argPDF16bDF16b_param_0,
11-
// CHECK: .param .b16 _Z8test_argPDF16bDF16b_param_1
11+
// CHECK: .param .align 2 .b8 _Z8test_argPDF16bDF16b_param_1[2]
1212
//
1313
__device__ void test_arg(__bf16 *out, __bf16 in) {
1414
// CHECK-DAG: ld.param.u64 %[[A:rd[0-9]+]], [_Z8test_argPDF16bDF16b_param_0];
@@ -20,8 +20,8 @@ __device__ void test_arg(__bf16 *out, __bf16 in) {
2020
}
2121

2222

23-
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z8test_retDF16b(
24-
// CHECK: .param .b16 _Z8test_retDF16b_param_0
23+
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z8test_retDF16b(
24+
// CHECK: .param .align 2 .b8 _Z8test_retDF16b_param_0[2]
2525
__device__ __bf16 test_ret( __bf16 in) {
2626
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z8test_retDF16b_param_0];
2727
return in;
@@ -31,12 +31,12 @@ __device__ __bf16 test_ret( __bf16 in) {
3131

3232
__device__ __bf16 external_func( __bf16 in);
3333

34-
// CHECK-LABEL: .visible .func (.param .b32 func_retval0) _Z9test_callDF16b(
35-
// CHECK: .param .b16 _Z9test_callDF16b_param_0
34+
// CHECK-LABEL: .visible .func (.param .align 2 .b8 func_retval0[2]) _Z9test_callDF16b(
35+
// CHECK: .param .align 2 .b8 _Z9test_callDF16b_param_0[2]
3636
__device__ __bf16 test_call( __bf16 in) {
3737
// CHECK: ld.param.b16 %[[R:rs[0-9]+]], [_Z9test_callDF16b_param_0];
3838
// CHECK: st.param.b16 [param0+0], %[[R]];
39-
// CHECK: .param .b32 retval0;
39+
// CHECK: .param .align 2 .b8 retval0[2];
4040
// CHECK: call.uni (retval0),
4141
// CHECK-NEXT: _Z13external_funcDF16b,
4242
// CHECK-NEXT: (

0 commit comments

Comments
 (0)