Skip to content

Commit 36a76f7

Browse files
committed
Modify per code review suggestions
1 parent 3e35f2f commit 36a76f7

File tree

2 files changed

+102
-105
lines changed

2 files changed

+102
-105
lines changed

CMakeLists.txt

Lines changed: 3 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -68,111 +68,9 @@ option(LLAMA_OPENBLAS "llama: use OpenBLAS"
6868
option(LLAMA_BUILD_TESTS "llama: build tests" ${LLAMA_STANDALONE})
6969
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
7070

71-
INCLUDE(CheckCSourceRuns)
72-
73-
SET(AVX_CODE "
74-
#include <immintrin.h>
75-
int main()
76-
{
77-
__m256 a;
78-
a = _mm256_set1_ps(0);
79-
return 0;
80-
}
81-
")
82-
83-
SET(AVX512_CODE "
84-
#include <immintrin.h>
85-
int main()
86-
{
87-
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
88-
0, 0, 0, 0, 0, 0, 0, 0,
89-
0, 0, 0, 0, 0, 0, 0, 0,
90-
0, 0, 0, 0, 0, 0, 0, 0,
91-
0, 0, 0, 0, 0, 0, 0, 0,
92-
0, 0, 0, 0, 0, 0, 0, 0,
93-
0, 0, 0, 0, 0, 0, 0, 0,
94-
0, 0, 0, 0, 0, 0, 0, 0);
95-
__m512i b = a;
96-
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
97-
return 0;
98-
}
99-
")
100-
101-
SET(AVX2_CODE "
102-
#include <immintrin.h>
103-
int main()
104-
{
105-
__m256i a = {0};
106-
a = _mm256_abs_epi16(a);
107-
__m256i x;
108-
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
109-
return 0;
110-
}
111-
")
112-
113-
SET(FMA_CODE "
114-
#include <immintrin.h>
115-
int main()
116-
{
117-
__m256 acc = _mm256_setzero_ps();
118-
const __m256 d = _mm256_setzero_ps();
119-
const __m256 p = _mm256_setzero_ps();
120-
acc = _mm256_fmadd_ps( d, p, acc );
121-
return 0;
122-
}
123-
")
124-
125-
MACRO(CHECK_SSE type flags)
126-
SET(__FLAG_I 1)
127-
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
128-
FOREACH(__FLAG ${flags})
129-
IF(NOT ${type}_FOUND)
130-
SET(CMAKE_REQUIRED_FLAGS ${__FLAG})
131-
CHECK_C_SOURCE_RUNS("${${type}_CODE}" HAS_${type}_${__FLAG_I})
132-
IF(HAS_${type}_${__FLAG_I})
133-
SET(${type}_FOUND TRUE CACHE BOOL "${type} support")
134-
SET(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
135-
ENDIF()
136-
MATH(EXPR __FLAG_I "${__FLAG_I}+1")
137-
ENDIF()
138-
ENDFOREACH()
139-
SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
140-
141-
IF(NOT ${type}_FOUND)
142-
SET(${type}_FOUND FALSE CACHE BOOL "${type} support")
143-
SET(${type}_FLAGS "" CACHE STRING "${type} flags")
144-
ENDIF()
145-
146-
MARK_AS_ADVANCED(${type}_FOUND ${type}_FLAGS)
147-
148-
ENDMACRO()
149-
150-
IF(${LLAMA_AVX})
151-
CHECK_SSE("AVX" " ;-mavx;/arch:AVX")
152-
IF(NOT ${AVX_FOUND})
153-
set(LLAMA_AVX OFF)
154-
ENDIF()
155-
ENDIF()
156-
157-
IF(${LLAMA_AVX2})
158-
CHECK_SSE("AVX2" " ;-mavx2 -mfma;/arch:AVX2")
159-
IF(NOT ${AVX2_FOUND})
160-
set(LLAMA_AVX2 OFF)
161-
ENDIF()
162-
ENDIF()
163-
164-
IF(${LLAMA_AVX512})
165-
CHECK_SSE("AVX512" " ;-mavx512f -mavx512dq -mavx512vl -mavx512bw -mfma;/arch:AVX512")
166-
IF(NOT ${AVX512_FOUND})
167-
set(LLAMA_AVX512 OFF)
168-
ENDIF()
169-
ENDIF()
170-
171-
IF(${LLAMA_FMA})
172-
CHECK_SSE("FMA" " ;-mfma;")
173-
IF (NOT ${FMA_FOUND})
174-
set(LLAMA_FMA OFF)
175-
ENDIF()
71+
MESSAGE("NATIVE=" ${LLAMA_NATIVE} " MSVC=" ${MSVC})
72+
IF(LLAMA_NATIVE AND MSVC)
73+
include(cmake/FindSIMD.cmake)
17674
ENDIF()
17775

17876
#

cmake/FindSIMD.cmake

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
INCLUDE(CheckCSourceRuns)
2+
3+
SET(AVX_CODE "
4+
#include <immintrin.h>
5+
int main()
6+
{
7+
__m256 a;
8+
a = _mm256_set1_ps(0);
9+
return 0;
10+
}
11+
")
12+
13+
SET(AVX512_CODE "
14+
#include <immintrin.h>
15+
int main()
16+
{
17+
__m512i a = _mm512_set_epi8(0, 0, 0, 0, 0, 0, 0, 0,
18+
0, 0, 0, 0, 0, 0, 0, 0,
19+
0, 0, 0, 0, 0, 0, 0, 0,
20+
0, 0, 0, 0, 0, 0, 0, 0,
21+
0, 0, 0, 0, 0, 0, 0, 0,
22+
0, 0, 0, 0, 0, 0, 0, 0,
23+
0, 0, 0, 0, 0, 0, 0, 0,
24+
0, 0, 0, 0, 0, 0, 0, 0);
25+
__m512i b = a;
26+
__mmask64 equality_mask = _mm512_cmp_epi8_mask(a, b, _MM_CMPINT_EQ);
27+
return 0;
28+
}
29+
")
30+
31+
SET(AVX2_CODE "
32+
#include <immintrin.h>
33+
int main()
34+
{
35+
__m256i a = {0};
36+
a = _mm256_abs_epi16(a);
37+
__m256i x;
38+
_mm256_extract_epi64(x, 0); // we rely on this in our AVX2 code
39+
return 0;
40+
}
41+
")
42+
43+
SET(FMA_CODE "
44+
#include <immintrin.h>
45+
int main()
46+
{
47+
__m256 acc = _mm256_setzero_ps();
48+
const __m256 d = _mm256_setzero_ps();
49+
const __m256 p = _mm256_setzero_ps();
50+
acc = _mm256_fmadd_ps( d, p, acc );
51+
return 0;
52+
}
53+
")
54+
55+
MACRO(CHECK_SSE type flags)
56+
SET(__FLAG_I 1)
57+
SET(CMAKE_REQUIRED_FLAGS_SAVE ${CMAKE_REQUIRED_FLAGS})
58+
FOREACH(__FLAG ${flags})
59+
IF(NOT ${type}_FOUND)
60+
SET(CMAKE_REQUIRED_FLAGS ${__FLAG})
61+
CHECK_C_SOURCE_RUNS("${${type}_CODE}" HAS_${type}_${__FLAG_I})
62+
IF(HAS_${type}_${__FLAG_I})
63+
SET(${type}_FOUND TRUE CACHE BOOL "${type} support")
64+
SET(${type}_FLAGS "${__FLAG}" CACHE STRING "${type} flags")
65+
ENDIF()
66+
MATH(EXPR __FLAG_I "${__FLAG_I}+1")
67+
ENDIF()
68+
ENDFOREACH()
69+
SET(CMAKE_REQUIRED_FLAGS ${CMAKE_REQUIRED_FLAGS_SAVE})
70+
71+
IF(NOT ${type}_FOUND)
72+
SET(${type}_FOUND FALSE CACHE BOOL "${type} support")
73+
SET(${type}_FLAGS "" CACHE STRING "${type} flags")
74+
ENDIF()
75+
76+
MARK_AS_ADVANCED(${type}_FOUND ${type}_FLAGS)
77+
78+
ENDMACRO()
79+
80+
CHECK_SSE("AVX" " ;/arch:AVX")
81+
IF(NOT ${AVX_FOUND})
82+
set(LLAMA_AVX OFF)
83+
ELSE()
84+
set(LLAMA_AVX ON)
85+
ENDIF()
86+
87+
CHECK_SSE("AVX2" " ;/arch:AVX2")
88+
IF(NOT ${AVX2_FOUND})
89+
set(LLAMA_AVX2 OFF)
90+
ELSE()
91+
set(LLAMA_AVX2 ON)
92+
ENDIF()
93+
94+
CHECK_SSE("AVX512" " ;/arch:AVX512")
95+
IF(NOT ${AVX512_FOUND})
96+
set(LLAMA_AVX512 OFF)
97+
ELSE()
98+
set(LLAMA_AVX512 ON)
99+
ENDIF()

0 commit comments

Comments
 (0)