Skip to content

Commit dbc5609

Browse files
authored
Merge pull request #16345 from ziglang/15920
Emit check for memory intrinsics for WebAssembly
2 parents e395a08 + 37e2a04 commit dbc5609

File tree

4 files changed

+127
-3
lines changed

4 files changed

+127
-3
lines changed

src/codegen/llvm.zig

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8511,6 +8511,14 @@ pub const FuncGen = struct {
85118511
const dest_ptr = self.sliceOrArrayPtr(dest_slice, ptr_ty);
85128512
const is_volatile = ptr_ty.isVolatilePtr(mod);
85138513

8514+
// Any WebAssembly runtime will trap when the destination pointer is out-of-bounds, regardless
8515+
// of the length. This means we need to emit a check where we skip the memset when the length
8516+
// is 0 as we allow for undefined pointers in 0-sized slices.
8517+
// This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
8518+
const intrinsic_len0_traps = o.target.isWasm() and
8519+
ptr_ty.isSlice(mod) and
8520+
std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory);
8521+
85148522
if (try self.air.value(bin_op.rhs, mod)) |elem_val| {
85158523
if (elem_val.isUndefDeep(mod)) {
85168524
// Even if safety is disabled, we still emit a memset to undefined since it conveys
@@ -8521,7 +8529,11 @@ pub const FuncGen = struct {
85218529
else
85228530
u8_llvm_ty.getUndef();
85238531
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
8524-
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8532+
if (intrinsic_len0_traps) {
8533+
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8534+
} else {
8535+
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8536+
}
85258537

85268538
if (safety and mod.comp.bin_file.options.valgrind) {
85278539
self.valgrindMarkUndef(dest_ptr, len);
@@ -8539,7 +8551,12 @@ pub const FuncGen = struct {
85398551
.val = byte_val,
85408552
});
85418553
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
8542-
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8554+
8555+
if (intrinsic_len0_traps) {
8556+
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8557+
} else {
8558+
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8559+
}
85438560
return null;
85448561
}
85458562
}
@@ -8551,7 +8568,12 @@ pub const FuncGen = struct {
85518568
// In this case we can take advantage of LLVM's intrinsic.
85528569
const fill_byte = try self.bitCast(value, elem_ty, Type.u8);
85538570
const len = self.sliceOrArrayLenInBytes(dest_slice, ptr_ty);
8554-
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8571+
8572+
if (intrinsic_len0_traps) {
8573+
try self.safeWasmMemset(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8574+
} else {
8575+
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8576+
}
85558577
return null;
85568578
}
85578579

@@ -8622,6 +8644,25 @@ pub const FuncGen = struct {
86228644
return null;
86238645
}
86248646

8647+
fn safeWasmMemset(
8648+
self: *FuncGen,
8649+
dest_ptr: *llvm.Value,
8650+
fill_byte: *llvm.Value,
8651+
len: *llvm.Value,
8652+
dest_ptr_align: u32,
8653+
is_volatile: bool,
8654+
) !void {
8655+
const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
8656+
const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
8657+
const memset_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapSkip");
8658+
const end_block = self.context.appendBasicBlock(self.llvm_func, "MemsetTrapEnd");
8659+
_ = self.builder.buildCondBr(cond, memset_block, end_block);
8660+
self.builder.positionBuilderAtEnd(memset_block);
8661+
_ = self.builder.buildMemSet(dest_ptr, fill_byte, len, dest_ptr_align, is_volatile);
8662+
_ = self.builder.buildBr(end_block);
8663+
self.builder.positionBuilderAtEnd(end_block);
8664+
}
8665+
86258666
fn airMemcpy(self: *FuncGen, inst: Air.Inst.Index) !?*llvm.Value {
86268667
const o = self.dg.object;
86278668
const mod = o.module;
@@ -8634,6 +8675,35 @@ pub const FuncGen = struct {
86348675
const len = self.sliceOrArrayLenInBytes(dest_slice, dest_ptr_ty);
86358676
const dest_ptr = self.sliceOrArrayPtr(dest_slice, dest_ptr_ty);
86368677
const is_volatile = src_ptr_ty.isVolatilePtr(mod) or dest_ptr_ty.isVolatilePtr(mod);
8678+
8679+
// When bulk-memory is enabled, this will be lowered to WebAssembly's memory.copy instruction.
8680+
// This instruction will trap on an invalid address, regardless of the length.
8681+
// For this reason we must add a check for 0-sized slices as its pointer field can be undefined.
8682+
// We only have to do this for slices as arrays will have a valid pointer.
8683+
// This logic can be removed once https://github.com/ziglang/zig/issues/16360 is done.
8684+
if (o.target.isWasm() and
8685+
std.Target.wasm.featureSetHas(o.target.cpu.features, .bulk_memory) and
8686+
dest_ptr_ty.isSlice(mod))
8687+
{
8688+
const llvm_usize_ty = self.context.intType(self.dg.object.target.ptrBitWidth());
8689+
const cond = try self.cmp(len, llvm_usize_ty.constInt(0, .False), Type.usize, .neq);
8690+
const memcpy_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapSkip");
8691+
const end_block = self.context.appendBasicBlock(self.llvm_func, "MemcpyTrapEnd");
8692+
_ = self.builder.buildCondBr(cond, memcpy_block, end_block);
8693+
self.builder.positionBuilderAtEnd(memcpy_block);
8694+
_ = self.builder.buildMemCpy(
8695+
dest_ptr,
8696+
dest_ptr_ty.ptrAlignment(mod),
8697+
src_ptr,
8698+
src_ptr_ty.ptrAlignment(mod),
8699+
len,
8700+
is_volatile,
8701+
);
8702+
_ = self.builder.buildBr(end_block);
8703+
self.builder.positionBuilderAtEnd(end_block);
8704+
return null;
8705+
}
8706+
86378707
_ = self.builder.buildMemCpy(
86388708
dest_ptr,
86398709
dest_ptr_ty.ptrAlignment(mod),

test/standalone.zig

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -230,6 +230,10 @@ pub const build_cases = [_]BuildCase{
230230
.build_root = "test/standalone/cmakedefine",
231231
.import = @import("standalone/cmakedefine/build.zig"),
232232
},
233+
.{
234+
.build_root = "test/standalone/zerolength_check",
235+
.import = @import("standalone/zerolength_check/build.zig"),
236+
},
233237
};
234238

235239
const std = @import("std");
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
const std = @import("std");
2+
3+
pub fn build(b: *std.Build) void {
4+
const test_step = b.step("test", "Test it");
5+
b.default_step = test_step;
6+
7+
add(b, test_step, .Debug);
8+
add(b, test_step, .ReleaseFast);
9+
add(b, test_step, .ReleaseSmall);
10+
add(b, test_step, .ReleaseSafe);
11+
}
12+
13+
fn add(b: *std.Build, test_step: *std.Build.Step, optimize: std.builtin.OptimizeMode) void {
14+
const unit_tests = b.addTest(.{
15+
.root_source_file = .{ .path = "src/main.zig" },
16+
.target = .{
17+
.os_tag = .wasi,
18+
.cpu_arch = .wasm32,
19+
.cpu_features_add = std.Target.wasm.featureSet(&.{.bulk_memory}),
20+
},
21+
.optimize = optimize,
22+
});
23+
24+
const run_unit_tests = b.addRunArtifact(unit_tests);
25+
run_unit_tests.skip_foreign_checks = true;
26+
test_step.dependOn(&run_unit_tests.step);
27+
}
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
const std = @import("std");
2+
3+
test {
4+
var dest = foo();
5+
var source = foo();
6+
7+
@memcpy(dest, source);
8+
@memset(dest, 4);
9+
@memset(dest, undefined);
10+
11+
var dest2 = foo2();
12+
@memset(dest2, 0);
13+
}
14+
15+
fn foo() []u8 {
16+
const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
17+
return @as([*]align(1) u8, @ptrFromInt(ptr))[0..0];
18+
}
19+
20+
fn foo2() []u64 {
21+
const ptr = comptime std.mem.alignBackward(usize, std.math.maxInt(usize), 1);
22+
return @as([*]align(1) u64, @ptrFromInt(ptr))[0..0];
23+
}

0 commit comments

Comments
 (0)