Skip to content

Commit 8cbc981

Browse files
committed
Added tests for vector functions on Builder
1 parent 9fd5cd1 commit 8cbc981

File tree

1 file changed

+130
-0
lines changed

1 file changed

+130
-0
lines changed

tests/test_builder.rs

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use self::inkwell::context::Context;
55
use self::inkwell::builder::Builder;
66
use self::inkwell::targets::{InitializationConfig, Target};
77
use self::inkwell::execution_engine::Symbol;
8+
use self::inkwell::types::BasicType;
89

910
use std::ffi::CString;
1011
use std::ptr::null;
@@ -447,3 +448,132 @@ fn test_no_builder_double_free2() {
447448
// 2nd Context drops fine
448449
// Builder drops fine
449450
}
451+
452+
#[test]
453+
fn test_vector_convert_ops() {
454+
let context = Context::create();
455+
let module = context.create_module("test");
456+
let int8_vec_type = context.i8_type().vec_type(3);
457+
let int32_vec_type = context.i32_type().vec_type(3);
458+
let float32_vec_type = context.f32_type().vec_type(3);
459+
let float16_vec_type = context.f16_type().vec_type(3);
460+
461+
// Here we're building a function that takes in a <3 x i8> and returns it casted to and from a <3 x i32>
462+
// Casting to and from means we can ensure the cast build functions return a vector when one is provided.
463+
let fn_type = int32_vec_type.fn_type(&[&int8_vec_type], false);
464+
let fn_value = module.add_function("test_int_vec_cast", &fn_type, None);
465+
let entry = fn_value.append_basic_block("entry");
466+
let builder = context.create_builder();
467+
468+
builder.position_at_end(&entry);
469+
let in_vec = fn_value.get_first_param().unwrap().into_vector_value();
470+
let casted_vec = builder.build_int_cast(in_vec, int32_vec_type, "casted_vec");
471+
let uncasted_vec = builder.build_int_cast(casted_vec, int8_vec_type, "uncasted_vec");
472+
builder.build_return(Some(&casted_vec));
473+
assert!(fn_value.verify(true));
474+
475+
// Here we're building a function that takes in a <3 x f32> and returns it casted to and from a <3 x f16>
476+
let fn_type = float16_vec_type.fn_type(&[&float32_vec_type], false);
477+
let fn_value = module.add_function("test_float_vec_cast", &fn_type, None);
478+
let entry = fn_value.append_basic_block("entry");
479+
let builder = context.create_builder();
480+
481+
builder.position_at_end(&entry);
482+
let in_vec = fn_value.get_first_param().unwrap().into_vector_value();
483+
let casted_vec = builder.build_float_cast(in_vec, float16_vec_type, "casted_vec");
484+
let uncasted_vec = builder.build_float_cast(casted_vec, float32_vec_type, "uncasted_vec");
485+
builder.build_return(Some(&casted_vec));
486+
assert!(fn_value.verify(true));
487+
488+
// Here we're building a function that takes in a <3 x f32> and returns it casted to and from a <3 x i32>
489+
let fn_type = int32_vec_type.fn_type(&[&float32_vec_type], false);
490+
let fn_value = module.add_function("test_float_to_int_vec_cast", &fn_type, None);
491+
let entry = fn_value.append_basic_block("entry");
492+
let builder = context.create_builder();
493+
494+
builder.position_at_end(&entry);
495+
let in_vec = fn_value.get_first_param().unwrap().into_vector_value();
496+
let casted_vec = builder.build_float_to_signed_int(in_vec, int32_vec_type, "casted_vec");
497+
let uncasted_vec = builder.build_signed_int_to_float(casted_vec, float32_vec_type, "uncasted_vec");
498+
builder.build_return(Some(&casted_vec));
499+
assert!(fn_value.verify(true));
500+
}
501+
502+
#[test]
503+
fn test_vector_binary_ops() {
504+
let context = Context::create();
505+
let module = context.create_module("test");
506+
let int32_vec_type = context.i32_type().vec_type(2);
507+
let float32_vec_type = context.f32_type().vec_type(2);
508+
let bool_vec_type = context.bool_type().vec_type(2);
509+
510+
// Here we're building a function that takes in three <2 x i32>s and returns them added together as a <2 x i32>
511+
let fn_type = int32_vec_type.fn_type(&[&int32_vec_type, &int32_vec_type, &int32_vec_type], false);
512+
let fn_value = module.add_function("test_int_vec_add", &fn_type, None);
513+
let entry = fn_value.append_basic_block("entry");
514+
let builder = context.create_builder();
515+
516+
builder.position_at_end(&entry);
517+
let p1_vec = fn_value.get_first_param().unwrap().into_vector_value();
518+
let p2_vec = fn_value.get_nth_param(1).unwrap().into_vector_value();
519+
let p3_vec = fn_value.get_nth_param(2).unwrap().into_vector_value();
520+
let added_vec = builder.build_int_add(p1_vec, p2_vec, "added_vec");
521+
let added_vec = builder.build_int_add(added_vec, p3_vec, "added_vec");
522+
builder.build_return(Some(&added_vec));
523+
assert!(fn_value.verify(true));
524+
525+
// Here we're building a function that takes in three <2 x f32>s and returns x * y / z as an
526+
// <2 x f32>
527+
let fn_type = float32_vec_type.fn_type(&[&float32_vec_type, &float32_vec_type, &float32_vec_type], false);
528+
let fn_value = module.add_function("test_float_vec_mul", &fn_type, None);
529+
let entry = fn_value.append_basic_block("entry");
530+
let builder = context.create_builder();
531+
532+
builder.position_at_end(&entry);
533+
let p1_vec = fn_value.get_first_param().unwrap().into_vector_value();
534+
let p2_vec = fn_value.get_nth_param(1).unwrap().into_vector_value();
535+
let p3_vec = fn_value.get_nth_param(2).unwrap().into_vector_value();
536+
let multiplied_vec = builder.build_float_mul(p1_vec, p2_vec, "multipled_vec");
537+
let divided_vec = builder.build_float_div(multiplied_vec, p3_vec, "divided_vec");
538+
builder.build_return(Some(&divided_vec));
539+
assert!(fn_value.verify(true));
540+
541+
// Here we're building a function that takes two <2 x f32>s and a <2 x bool> and returns (x < y) * z
542+
// as a <2 x bool>
543+
let fn_type = bool_vec_type.fn_type(&[&float32_vec_type, &float32_vec_type, &bool_vec_type], false);
544+
let fn_value = module.add_function("test_float_vec_compare", &fn_type, None);
545+
let entry = fn_value.append_basic_block("entry");
546+
let builder = context.create_builder();
547+
548+
builder.position_at_end(&entry);
549+
let p1_vec = fn_value.get_first_param().unwrap().into_vector_value();
550+
let p2_vec = fn_value.get_nth_param(1).unwrap().into_vector_value();
551+
let p3_vec = fn_value.get_nth_param(2).unwrap().into_vector_value();
552+
let compared_vec = builder.build_float_compare(self::inkwell::FloatPredicate::OLT, p1_vec, p2_vec, "compared_vec");
553+
let multiplied_vec = builder.build_int_mul(compared_vec, p3_vec, "multiplied_vec");
554+
builder.build_return(Some(&multiplied_vec));
555+
assert!(fn_value.verify(true));
556+
}
557+
558+
#[test]
559+
fn test_vector_pointer_ops() {
560+
let context = Context::create();
561+
let module = context.create_module("test");
562+
let int32_vec_type = context.i32_type().vec_type(4);
563+
let i8_ptr_vec_type = context.i8_type().ptr_type(AddressSpace::Generic).vec_type(4);
564+
let bool_vec_type = context.bool_type().vec_type(4);
565+
566+
// Here we're building a function that takes a <4 x i32>, converts it to a <4 x i8*> and returns a
567+
// <4 x bool> if the pointer is null
568+
let fn_type = bool_vec_type.fn_type(&[&int32_vec_type], false);
569+
let fn_value = module.add_function("test_ptr_null", &fn_type, None);
570+
let entry = fn_value.append_basic_block("entry");
571+
let builder = context.create_builder();
572+
573+
builder.position_at_end(&entry);
574+
let in_vec = fn_value.get_first_param().unwrap().into_vector_value();
575+
let ptr_vec = builder.build_int_to_ptr(in_vec, i8_ptr_vec_type, "ptr_vec");
576+
let is_null_vec = builder.build_is_null(ptr_vec, "is_null_vec");
577+
builder.build_return(Some(&is_null_vec));
578+
assert!(fn_value.verify(true));
579+
}

0 commit comments

Comments
 (0)