Skip to content

Commit bc7a907

Browse files
committed
[ADT] Add implementations for avgFloor and avgCeil to APInt
Supports both signed and unsigned expansions. SelectionDAG now calls the APInt implementation of these functions.
1 parent d9c8550 commit bc7a907

File tree

4 files changed

+123
-24
lines changed

4 files changed

+123
-24
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2193,6 +2193,18 @@ inline const APInt absdiff(const APInt &A, const APInt &B) {
21932193
return A.uge(B) ? (A - B) : (B - A);
21942194
}
21952195

2196+
/// Compute the floor of the signed average of C1 and C2
2197+
APInt avgFloorS(const APInt &C1, const APInt &C2);
2198+
2199+
/// Compute the floor of the unsigned average of C1 and C2
2200+
APInt avgFloorU(const APInt &C1, const APInt &C2);
2201+
2202+
/// Compute the ceil of the signed average of C1 and C2
2203+
APInt avgCeilS(const APInt &C1, const APInt &C2);
2204+
2205+
/// Compute the ceil of the unsigned average of C1 and C2
2206+
APInt avgCeilU(const APInt &C1, const APInt &C2);
2207+
21962208
/// Compute GCD of two unsigned APInt values.
21972209
///
21982210
/// This function returns the greatest common divisor of the two APInt values

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6021,30 +6021,14 @@ static std::optional<APInt> FoldValue(unsigned Opcode, const APInt &C1,
60216021
APInt C2Ext = C2.zext(FullWidth);
60226022
return (C1Ext * C2Ext).extractBits(C1.getBitWidth(), C1.getBitWidth());
60236023
}
6024-
case ISD::AVGFLOORS: {
6025-
unsigned FullWidth = C1.getBitWidth() + 1;
6026-
APInt C1Ext = C1.sext(FullWidth);
6027-
APInt C2Ext = C2.sext(FullWidth);
6028-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6029-
}
6030-
case ISD::AVGFLOORU: {
6031-
unsigned FullWidth = C1.getBitWidth() + 1;
6032-
APInt C1Ext = C1.zext(FullWidth);
6033-
APInt C2Ext = C2.zext(FullWidth);
6034-
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
6035-
}
6036-
case ISD::AVGCEILS: {
6037-
unsigned FullWidth = C1.getBitWidth() + 1;
6038-
APInt C1Ext = C1.sext(FullWidth);
6039-
APInt C2Ext = C2.sext(FullWidth);
6040-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6041-
}
6042-
case ISD::AVGCEILU: {
6043-
unsigned FullWidth = C1.getBitWidth() + 1;
6044-
APInt C1Ext = C1.zext(FullWidth);
6045-
APInt C2Ext = C2.zext(FullWidth);
6046-
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
6047-
}
6024+
case ISD::AVGFLOORS:
6025+
return APIntOps::avgFloorS(C1, C2);
6026+
case ISD::AVGFLOORU:
6027+
return APIntOps::avgFloorU(C1, C2);
6028+
case ISD::AVGCEILS:
6029+
return APIntOps::avgCeilS(C1, C2);
6030+
case ISD::AVGCEILU:
6031+
return APIntOps::avgCeilU(C1, C2);
60486032
case ISD::ABDS:
60496033
return APIntOps::smax(C1, C2) - APIntOps::smin(C1, C2);
60506034
case ISD::ABDU:

llvm/lib/Support/APInt.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3094,3 +3094,35 @@ void llvm::LoadIntFromMemory(APInt &IntVal, const uint8_t *Src,
30943094
memcpy(Dst + sizeof(uint64_t) - LoadBytes, Src, LoadBytes);
30953095
}
30963096
}
3097+
3098+
APInt APIntOps::avgFloorS(const APInt &C1, const APInt &C2) {
3099+
// Return floor((C1 + C2)/2))
3100+
unsigned FullWidth = C1.getBitWidth() + 1;
3101+
APInt C1Ext = C1.sext(FullWidth);
3102+
APInt C2Ext = C2.sext(FullWidth);
3103+
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
3104+
}
3105+
3106+
APInt APIntOps::avgFloorU(const APInt &C1, const APInt &C2) {
3107+
// Return floor((C1 + C2)/2))
3108+
unsigned FullWidth = C1.getBitWidth() + 1;
3109+
APInt C1Ext = C1.zext(FullWidth);
3110+
APInt C2Ext = C2.zext(FullWidth);
3111+
return (C1Ext + C2Ext).extractBits(C1.getBitWidth(), 1);
3112+
}
3113+
3114+
APInt APIntOps::avgCeilS(const APInt &C1, const APInt &C2) {
3115+
// Return ceil((C1 + C2)/2))
3116+
unsigned FullWidth = C1.getBitWidth() + 1;
3117+
APInt C1Ext = C1.sext(FullWidth);
3118+
APInt C2Ext = C2.sext(FullWidth);
3119+
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
3120+
}
3121+
3122+
APInt APIntOps::avgCeilU(const APInt &C1, const APInt &C2) {
3123+
// Return ceil((C1 + C2)/2))
3124+
unsigned FullWidth = C1.getBitWidth() + 1;
3125+
APInt C1Ext = C1.zext(FullWidth);
3126+
APInt C2Ext = C2.zext(FullWidth);
3127+
return (C1Ext + C2Ext + 1).extractBits(C1.getBitWidth(), 1);
3128+
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "llvm/Support/Alignment.h"
1515
#include "gtest/gtest.h"
1616
#include <array>
17+
#include <limits.h>
1718
#include <optional>
1819

1920
using namespace llvm;
@@ -2877,6 +2878,76 @@ TEST(APIntTest, RoundingSDiv) {
28772878
}
28782879
}
28792880

2881+
TEST(APIntTest, Average) {
2882+
APInt A0(32, 0);
2883+
APInt A2(32, 2);
2884+
APInt A100(32, 100);
2885+
APInt A101(32, 101);
2886+
APInt A200(32, 200, false);
2887+
APInt ApUMax(32, UINT_MAX, false);
2888+
2889+
EXPECT_EQ(APInt(32, 150), APIntOps::avgFloorU(A100, A200));
2890+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::DOWN),
2891+
APIntOps::avgFloorU(A100, A200));
2892+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A200, A2, APInt::Rounding::UP),
2893+
APIntOps::avgCeilU(A100, A200));
2894+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::DOWN),
2895+
APIntOps::avgFloorU(A100, A101));
2896+
EXPECT_EQ(APIntOps::RoundingUDiv(A100 + A101, A2, APInt::Rounding::UP),
2897+
APIntOps::avgCeilU(A100, A101));
2898+
EXPECT_EQ(A0, APIntOps::avgFloorU(A0, A0));
2899+
EXPECT_EQ(A0, APIntOps::avgCeilU(A0, A0));
2900+
EXPECT_EQ(ApUMax, APIntOps::avgFloorU(ApUMax, ApUMax));
2901+
EXPECT_EQ(ApUMax, APIntOps::avgCeilU(ApUMax, ApUMax));
2902+
2903+
APInt Ap100(32, +100);
2904+
APInt Ap101(32, +101);
2905+
APInt Ap200(32, +200);
2906+
APInt Am1(32, -1);
2907+
APInt Am100(32, -100);
2908+
APInt Am101(32, -101);
2909+
APInt Am200(32, -200);
2910+
APInt AmSMin(32, INT_MIN);
2911+
APInt ApSMax(32, INT_MAX);
2912+
2913+
EXPECT_EQ(APInt(32, +150), APIntOps::avgFloorS(Ap100, Ap200));
2914+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::DOWN),
2915+
APIntOps::avgFloorS(Ap100, Ap200));
2916+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap200, A2, APInt::Rounding::UP),
2917+
APIntOps::avgCeilS(Ap100, Ap200));
2918+
2919+
EXPECT_EQ(APInt(32, -150), APIntOps::avgFloorS(Am100, Am200));
2920+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::DOWN),
2921+
APIntOps::avgFloorS(Am100, Am200));
2922+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am200, A2, APInt::Rounding::UP),
2923+
APIntOps::avgCeilS(Am100, Am200));
2924+
2925+
EXPECT_EQ(APInt(32, +100), APIntOps::avgFloorS(Ap100, Ap101));
2926+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::DOWN),
2927+
APIntOps::avgFloorS(Ap100, Ap101));
2928+
EXPECT_EQ(APInt(32, +101), APIntOps::avgCeilS(Ap100, Ap101));
2929+
EXPECT_EQ(APIntOps::RoundingSDiv(Ap100 + Ap101, A2, APInt::Rounding::UP),
2930+
APIntOps::avgCeilS(Ap100, Ap101));
2931+
2932+
EXPECT_EQ(APInt(32, -101), APIntOps::avgFloorS(Am100, Am101));
2933+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::DOWN),
2934+
APIntOps::avgFloorS(Am100, Am101));
2935+
EXPECT_EQ(APInt(32, -100), APIntOps::avgCeilS(Am100, Am101));
2936+
EXPECT_EQ(APIntOps::RoundingSDiv(Am100 + Am101, A2, APInt::Rounding::UP),
2937+
APIntOps::avgCeilS(Am100, Am101));
2938+
2939+
EXPECT_EQ(AmSMin, APIntOps::avgFloorS(AmSMin, AmSMin));
2940+
EXPECT_EQ(AmSMin, APIntOps::avgCeilS(AmSMin, AmSMin));
2941+
2942+
EXPECT_EQ(A0, APIntOps::avgFloorS(A0, A0));
2943+
EXPECT_EQ(A0, APIntOps::avgCeilS(A0, A0));
2944+
EXPECT_EQ(Am1, APIntOps::avgFloorS(AmSMin, ApSMax));
2945+
EXPECT_EQ(A0, APIntOps::avgCeilS(AmSMin, ApSMax));
2946+
2947+
EXPECT_EQ(ApSMax, APIntOps::avgFloorS(ApSMax, ApSMax));
2948+
EXPECT_EQ(ApSMax, APIntOps::avgCeilS(ApSMax, ApSMax));
2949+
}
2950+
28802951
TEST(APIntTest, umul_ov) {
28812952
const std::pair<uint64_t, uint64_t> Overflows[] = {
28822953
{0x8000000000000000, 2},

0 commit comments

Comments
 (0)