Skip to content

Commit 50e94ee

Browse files
authored
Merge pull request rust-lang#47 from cpdt/feature/45-vector-math-ops
Support vector types in math ops, fixes rust-lang#45
2 parents ce4519b + 8cbc981 commit 50e94ee

File tree

12 files changed

+361
-141
lines changed

12 files changed

+361
-141
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,8 @@ fn jit_compile_sum(
106106
let y = function.get_nth_param(1)?.into_int_value();
107107
let z = function.get_nth_param(2)?.into_int_value();
108108

109-
let sum = builder.build_int_add(&x, &y, "sum");
110-
let sum = builder.build_int_add(&sum, &z, "sum");
109+
let sum = builder.build_int_add(x, y, "sum");
110+
let sum = builder.build_int_add(sum, z, "sum");
111111

112112
builder.build_return(Some(&sum));
113113

examples/jit.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ fn jit_compile_sum(
3232
let y = function.get_nth_param(1)?.into_int_value();
3333
let z = function.get_nth_param(2)?.into_int_value();
3434

35-
let sum = builder.build_int_add(&x, &y, "sum");
36-
let sum = builder.build_int_add(&sum, &z, "sum");
35+
let sum = builder.build_int_add(x, y, "sum");
36+
let sum = builder.build_int_add(sum, z, "sum");
3737

3838
builder.build_return(Some(&sum));
3939

examples/kaleidoscope/main.rs

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -941,19 +941,19 @@ impl<'a> Compiler<'a> {
941941
let rhs = self.compile_expr(right)?;
942942

943943
match op {
944-
'+' => Ok(self.builder.build_float_add(&lhs, &rhs, "tmpadd")),
945-
'-' => Ok(self.builder.build_float_sub(&lhs, &rhs, "tmpsub")),
946-
'*' => Ok(self.builder.build_float_mul(&lhs, &rhs, "tmpmul")),
947-
'/' => Ok(self.builder.build_float_div(&lhs, &rhs, "tmpdiv")),
944+
'+' => Ok(self.builder.build_float_add(lhs, rhs, "tmpadd")),
945+
'-' => Ok(self.builder.build_float_sub(lhs, rhs, "tmpsub")),
946+
'*' => Ok(self.builder.build_float_mul(lhs, rhs, "tmpmul")),
947+
'/' => Ok(self.builder.build_float_div(lhs, rhs, "tmpdiv")),
948948
'<' => Ok({
949-
let cmp = self.builder.build_float_compare(FloatPredicate::ULT, &lhs, &rhs, "tmpcmp");
949+
let cmp = self.builder.build_float_compare(FloatPredicate::ULT, lhs, rhs, "tmpcmp");
950950

951-
self.builder.build_unsigned_int_to_float(&cmp, &self.context.f64_type(), "tmpbool")
951+
self.builder.build_unsigned_int_to_float(cmp, self.context.f64_type(), "tmpbool")
952952
}),
953953
'>' => Ok({
954-
let cmp = self.builder.build_float_compare(FloatPredicate::ULT, &rhs, &lhs, "tmpcmp");
954+
let cmp = self.builder.build_float_compare(FloatPredicate::ULT, rhs, lhs, "tmpcmp");
955955

956-
self.builder.build_unsigned_int_to_float(&cmp, &self.context.f64_type(), "tmpbool")
956+
self.builder.build_unsigned_int_to_float(cmp, self.context.f64_type(), "tmpbool")
957957
}),
958958

959959
custom => {
@@ -1002,7 +1002,7 @@ impl<'a> Compiler<'a> {
10021002

10031003
// create condition by comparing without 0.0 and returning an int
10041004
let cond = self.compile_expr(cond)?;
1005-
let cond = self.builder.build_float_compare(FloatPredicate::ONE, &cond, &zero_const, "ifcond");
1005+
let cond = self.builder.build_float_compare(FloatPredicate::ONE, cond, zero_const, "ifcond");
10061006

10071007
// build branch
10081008
let then_bb = self.context.append_basic_block(&parent, "then");
@@ -1069,11 +1069,11 @@ impl<'a> Compiler<'a> {
10691069
let end_cond = self.compile_expr(end)?;
10701070

10711071
let curr_var = self.builder.build_load(&start_alloca, var_name);
1072-
let next_var = self.builder.build_float_add(curr_var.as_float_value(), &step, "nextvar");
1072+
let next_var = self.builder.build_float_add(curr_var.into_float_value(), step, "nextvar");
10731073

10741074
self.builder.build_store(&start_alloca, &next_var);
10751075

1076-
let end_cond = self.builder.build_float_compare(FloatPredicate::ONE, &end_cond, &self.context.f64_type().const_float(0.0), "loopcond");
1076+
let end_cond = self.builder.build_float_compare(FloatPredicate::ONE, end_cond, self.context.f64_type().const_float(0.0), "loopcond");
10771077
let after_bb = self.context.append_basic_block(&parent, "afterloop");
10781078

10791079
self.builder.build_conditional_branch(&end_cond, &loop_bb, &after_bb);

src/builder.rs

Lines changed: 108 additions & 104 deletions
Large diffs are not rendered by default.

src/types/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ pub use types::fn_type::FunctionType;
1616
pub use types::int_type::IntType;
1717
pub use types::ptr_type::PointerType;
1818
pub use types::struct_type::StructType;
19-
pub use types::traits::{AnyType, BasicType};
19+
pub use types::traits::{AnyType, BasicType, IntMathType, FloatMathType, PointerMathType};
2020
pub use types::vec_type::VectorType;
2121
pub use types::void_type::VoidType;
2222
pub(crate) use types::traits::AsTypeRef;

src/types/traits.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use std::fmt::Debug;
44

55
use types::{IntType, FunctionType, FloatType, PointerType, StructType, ArrayType, VectorType, VoidType, Type};
66
use types::enums::{AnyTypeEnum, BasicTypeEnum};
7+
use values::{IntMathValue, FloatMathValue, PointerMathValue, IntValue, FloatValue, PointerValue, VectorValue};
78

89
// This is an ugly privacy hack so that Type can stay private to this module
910
// and so that super traits using this trait will be not be implementable
@@ -40,5 +41,56 @@ pub trait BasicType: AnyType {
4041
}
4142
}
4243

44+
/// Represents an LLVM type that can have integer math operations applied to it.
45+
pub trait IntMathType: BasicType {
46+
type ValueType: IntMathValue;
47+
type MathConvType: FloatMathType;
48+
type PtrConvType: PointerMathType;
49+
}
50+
51+
/// Represents an LLVM type that can have floating point math operations applied to it.
52+
pub trait FloatMathType: BasicType {
53+
type ValueType: FloatMathValue;
54+
type MathConvType: IntMathType;
55+
}
56+
57+
/// Represents an LLVM type that can have pointer operations applied to it.
58+
pub trait PointerMathType: BasicType {
59+
type ValueType: PointerMathValue;
60+
type PtrConvType: IntMathType;
61+
}
62+
4363
trait_type_set! {AnyType: AnyTypeEnum, BasicTypeEnum, IntType, FunctionType, FloatType, PointerType, StructType, ArrayType, VoidType, VectorType}
4464
trait_type_set! {BasicType: BasicTypeEnum, IntType, FloatType, PointerType, StructType, ArrayType, VectorType}
65+
66+
impl IntMathType for IntType {
67+
type ValueType = IntValue;
68+
type MathConvType = FloatType;
69+
type PtrConvType = PointerType;
70+
}
71+
72+
impl IntMathType for VectorType {
73+
type ValueType = VectorValue;
74+
type MathConvType = VectorType;
75+
type PtrConvType = VectorType;
76+
}
77+
78+
impl FloatMathType for FloatType {
79+
type ValueType = FloatValue;
80+
type MathConvType = IntType;
81+
}
82+
83+
impl FloatMathType for VectorType {
84+
type ValueType = VectorValue;
85+
type MathConvType = VectorType;
86+
}
87+
88+
impl PointerMathType for PointerType {
89+
type ValueType = PointerValue;
90+
type PtrConvType = IntType;
91+
}
92+
93+
impl PointerMathType for VectorType {
94+
type ValueType = VectorValue;
95+
type PtrConvType = VectorType;
96+
}

src/values/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ pub use values::metadata_value::{MetadataValue, FIRST_CUSTOM_METADATA_KIND_ID};
2525
pub use values::phi_value::PhiValue;
2626
pub use values::ptr_value::PointerValue;
2727
pub use values::struct_value::StructValue;
28-
pub use values::traits::{AnyValue, AggregateValue, BasicValue};
28+
pub use values::traits::{AnyValue, AggregateValue, BasicValue, IntMathValue, FloatMathValue, PointerMathValue};
2929
pub use values::vec_value::VectorValue;
3030
pub(crate) use values::traits::AsValueRef;
3131

src/values/traits.rs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use llvm_sys::prelude::LLVMValueRef;
33
use std::fmt::Debug;
44

55
use values::{ArrayValue, AggregateValueEnum, GlobalValue, StructValue, BasicValueEnum, AnyValueEnum, IntValue, FloatValue, PointerValue, PhiValue, VectorValue, FunctionValue, InstructionValue};
6+
use types::{IntMathType, FloatMathType, PointerMathType, IntType, FloatType, PointerType, VectorType};
67

78
// This is an ugly privacy hack so that Type can stay private to this module
89
// and so that super traits using this trait will be not be implementable
@@ -22,6 +23,19 @@ macro_rules! trait_value_set {
2223
);
2324
}
2425

26+
macro_rules! math_trait_value_set {
27+
($trait_name:ident: $(($value_type:ident => $base_type:ident)),*) => (
28+
$(
29+
impl $trait_name for $value_type {
30+
type BaseType = $base_type;
31+
fn new(value: LLVMValueRef) -> Self {
32+
$value_type::new(value)
33+
}
34+
}
35+
)*
36+
)
37+
}
38+
2539
/// Represents an aggregate value, built on top of other values.
2640
pub trait AggregateValue: BasicValue {
2741
/// Returns an enum containing a typed version of the `AggregateValue`.
@@ -38,6 +52,23 @@ pub trait BasicValue: AnyValue {
3852
}
3953
}
4054

55+
/// Represents a value which is permitted in integer math operations
56+
pub trait IntMathValue: BasicValue {
57+
type BaseType: IntMathType;
58+
fn new(value: LLVMValueRef) -> Self;
59+
}
60+
61+
/// Represents a value which is permitted in floating point math operations
62+
pub trait FloatMathValue: BasicValue {
63+
type BaseType: FloatMathType;
64+
fn new(value: LLVMValueRef) -> Self;
65+
}
66+
67+
pub trait PointerMathValue: BasicValue {
68+
type BaseType: PointerMathType;
69+
fn new(value: LLVMValueRef) -> Self;
70+
}
71+
4172
/// Defines any struct wrapping an LLVM value.
4273
pub trait AnyValue: AsValueRef + Debug {
4374
/// Returns an enum containing a typed version of `AnyValue`.
@@ -49,3 +80,6 @@ pub trait AnyValue: AsValueRef + Debug {
4980
trait_value_set! {AggregateValue: ArrayValue, AggregateValueEnum, StructValue}
5081
trait_value_set! {AnyValue: AnyValueEnum, BasicValueEnum, AggregateValueEnum, ArrayValue, IntValue, FloatValue, GlobalValue, PhiValue, PointerValue, FunctionValue, StructValue, VectorValue, InstructionValue}
5182
trait_value_set! {BasicValue: ArrayValue, BasicValueEnum, AggregateValueEnum, IntValue, FloatValue, GlobalValue, StructValue, PointerValue, VectorValue}
83+
math_trait_value_set! {IntMathValue: (IntValue => IntType), (VectorValue => VectorType)}
84+
math_trait_value_set! {FloatMathValue: (FloatValue => FloatType), (VectorValue => VectorType)}
85+
math_trait_value_set! {PointerMathValue: (PointerValue => PointerType), (VectorValue => VectorType)}

0 commit comments

Comments
 (0)