
Commit dc7d200

Author: Github Executorch
Update base for Update on "Integrate torchgen exception boundary with ExecuTorch"
As of #7746, we build with exceptions by default, so we just need to use them.

Differential Revision: [D67904052](https://our.internmc.facebook.com/intern/diff/D67904052/)

[ghstack-poisoned]
2 parents 8ead435 + 832f855 commit dc7d200

22 files changed: +292 additions, -115 deletions


backends/cadence/aot/compiler.py

Lines changed: 2 additions & 0 deletions
@@ -264,6 +264,7 @@ def export_to_executorch_gen_etrecord(
     alloc_graph_output: bool = True,
     memory_config: Optional[MemoryConfig] = None,
     dump_graphs: bool = False,
+    mem_alignment: int = 1,
 ) -> ExecutorchProgramManager:
     cadence_passes = get_cadence_passes(opt_level)
     edge_prog_manager = export_to_edge(model, inputs, dump_graphs)
@@ -290,6 +291,7 @@ def export_to_executorch_gen_etrecord(
         mem_algo=mem_algo,
         alloc_graph_input=alloc_graph_input,
         alloc_graph_output=alloc_graph_output,
+        mem_alignment=mem_alignment,
     )

     # Get executorch program after Cadence specific passes
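For context, a minimal sketch of how a caller would request aligned buffer starts through the new parameter (modeled on the test added in this commit; the toy module, inputs, and the 16-byte value are illustrative placeholders, not part of the change itself):

import torch
from executorch.backends.cadence.aot import compiler

class TinyAdd(torch.nn.Module):
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return torch.add(x, y)

# Every planned buffer should now start on a 16-byte boundary.
prog = compiler.export_to_executorch_gen_etrecord(
    TinyAdd(),
    (torch.randn(4, 17), torch.randn(4, 17)),
    opt_level=1,
    mem_algo=0,
    mem_alignment=16,
)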

backends/cadence/aot/memory_planning.py

Lines changed: 23 additions & 8 deletions
@@ -9,6 +9,7 @@
 import collections
 import itertools
 import logging
+import math
 import typing
 from functools import partial
 from typing import Iterable, List, Optional, Tuple
@@ -39,6 +40,10 @@ def get_size(memory_config: MemoryConfig, exir_id: int) -> int:
     return memory_config.memory_sizes[exir_id - 1]


+def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
+    return int(math.ceil(pre_aligned_offset / alignment) * alignment)
+
+
 def collect_specs_from_graph_module(
     graph_module: torch.fx.GraphModule,
     alloc_graph_input: bool,
@@ -95,9 +100,9 @@ def overlap(spec: TensorSpec) -> Optional[TensorSpec]:
         return None

     def memory_available(spec: TensorSpec) -> bool:
-        return spec.mem_offset + spec.allocated_memory <= get_size(
-            memory_config, spec.mem_id
-        )
+        return get_aligned_offset(
+            spec.mem_offset + spec.allocated_memory, alignment
+        ) <= get_size(memory_config, spec.mem_id)

     # Iterate over all the specs in sorted order
     for spec in sorted(
@@ -116,7 +121,9 @@ def memory_available(spec: TensorSpec) -> bool:
             continue
         spec.mem_offset = 0
         while memory_available(spec) and (overlapped := overlap(spec)):
-            spec.mem_offset = overlapped.mem_offset + overlapped.allocated_memory
+            spec.mem_offset = get_aligned_offset(
+                overlapped.mem_offset + overlapped.allocated_memory, alignment
+            )
         if memory_available(spec):
             allocated_buffers[spec.mem_id].append(spec)
             bufsizes[spec.mem_id] = max(
@@ -202,13 +209,16 @@ def greedy_by_size_for_offset_calculation_with_hierarchy(
             # calculation of gap incorrect. Moving it out will make the algorithm degenerate
             # to the naive one, reusing 0 tensor. The paper may have a typo here.
             prev_offset = max(
-                allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                get_aligned_offset(
+                    allocated_spec.mem_offset + allocated_spec.allocated_memory,
+                    alignment,
+                ),
                 prev_offset,
             )
         if spec.mem_offset is None:
-            if prev_offset + spec.allocated_memory > get_size(
-                memory_config, spec.mem_id
-            ):
+            if get_aligned_offset(
+                prev_offset + spec.allocated_memory, alignment
+            ) > get_size(memory_config, spec.mem_id):
                 continue
             else:
                 spec.mem_offset = prev_offset
@@ -423,6 +433,7 @@ def __init__(
                 ]
             ]
         ] = None,
+        mem_alignment: int = 1,
     ) -> None:
         self._init_mem_algos()

@@ -433,6 +444,9 @@ def __init__(
         self.alloc_graph_output = alloc_graph_output
         self.additional_constraint_gen_passes = additional_constraint_gen_passes

+        assert mem_alignment > 0, "mem_alignment must be positive"
+        self.mem_alignment = mem_alignment
+
     def _init_mem_algos(self) -> None:
         self.available_mem_algos = [
             position_based_greedy_with_hierarchy,
@@ -459,6 +473,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
             allow_lifetime_and_storage_overlap=(self.opt_level >= 2),
             alloc_graph_input=self.alloc_graph_input,
             alloc_graph_output=self.alloc_graph_output,
+            alignment=self.mem_alignment,
         )
         mem_planning(graph_module)

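For concreteness, the rounding behaviour of the new helper (a quick sketch; the function body is copied verbatim from the hunk above):

import math

def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int:
    # Round an offset up to the next multiple of `alignment`.
    return int(math.ceil(pre_aligned_offset / alignment) * alignment)

assert get_aligned_offset(96, 16) == 96    # already-aligned offsets are unchanged
assert get_aligned_offset(100, 16) == 112  # 100 rounds up to the next 16-byte boundary
assert get_aligned_offset(5, 37) == 37     # non-power-of-two alignments work the same way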

backends/cadence/aot/tests/test_memory_passes.py

Lines changed: 57 additions & 19 deletions
@@ -14,11 +14,12 @@
 from executorch.backends.cadence.aot.pass_utils import count_node
 from executorch.exir import memory
 from executorch.exir.dialects._ops import ops as exir_ops
+from executorch.exir.memory_planning import collect_specs_from_nodes
 from executorch.exir.tests.models import MultiLayerPerceptron


 class TestMemPlanningPasses(unittest.TestCase):
-    def test_calculate_peak_memory_pass(self):
+    def test_calculate_peak_memory_pass(self) -> None:
         class PeakMemoryTestModel(torch.nn.Module):
             def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
                 super().__init__()
@@ -32,7 +33,7 @@ def forward(self, x: torch.Tensor):
                 x = self.linear2(x)
                 return x

-        def calculate_aligned_num_bytes(num: int, alignment: int = 16):
+        def calculate_aligned_num_bytes(num: int, alignment: int = 16) -> int:
             return math.ceil(num / alignment) * alignment

         # model 1
@@ -86,7 +87,7 @@ def calculate_aligned_num_bytes(num: int, alignment: int = 16):
         )  # Align data on a 16 byte boundary
         self.assertEqual(peak_usage, expected_peak_usage)

-    def test_zero_memory_pass(self):
+    def test_zero_memory_pass(self) -> None:
         class ZeroMem(torch.nn.Module):
             def forward(self, x):
                 return x[:, 2::3, ...]
@@ -188,7 +189,7 @@ def _verify_select_nop_memory_alloc(self, node: torch.fx.Node) -> None:
                 f"{spec=} {arg_spec=}",
             )

-    def verify_nop_memory_alloc(self, graph_module):
+    def verify_nop_memory_alloc(self, graph_module: torch.fx.GraphModule) -> None:
         for node in graph_module.graph.find_nodes(
             op="call_function", target=torch.ops.aten._cat_nop.out
         ):
@@ -204,7 +205,7 @@ def verify_nop_memory_alloc(self, graph_module):
         ):
             self._verify_select_nop_memory_alloc(node)

-    def test_optimize_cat_on_placeholders(self):
+    def test_optimize_cat_on_placeholders(self) -> None:
         class Cat(torch.nn.Module):
             def forward(self, x, y):
                 return torch.ops.aten.cat((x, y))
@@ -228,7 +229,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_outermost(self):
+    def test_optimize_cat_outermost(self) -> None:
         class OptimizeCatFeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -255,7 +256,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_non_outermost(self):
+    def test_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatFeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -282,7 +283,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost(self):
+    def test_no_optimize_cat_non_outermost(self) -> None:
         class OptimizeCatInfeasible1(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -308,7 +309,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_non_outermost1(self):
+    def test_no_optimize_cat_non_outermost1(self) -> None:
         class OptimizeCatInfeasible2(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -335,7 +336,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice(self):
+    def test_optimize_cat_with_slice(self) -> None:
         class OptimizeCatSliceFeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -364,7 +365,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_slice_infeasible(self):
+    def test_optimize_cat_with_slice_infeasible(self) -> None:
         class OptimizeCatSliceInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -390,7 +391,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_slice_Tensor(self):
+    def test_optimize_slice_Tensor(self) -> None:
         class SliceTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -452,7 +453,7 @@ def forward(self, x, y, z):
         )
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_select_Tensor(self):
+    def test_optimize_select_Tensor(self) -> None:
         class SelectTensor(torch.nn.Module):
             def forward(self, x, y, z):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -519,7 +520,7 @@ def forward(self, x, y, z):

     # TODO: Test fails due to memory planning
     @unittest.expectedFailure
-    def test_optimize_cat_with_param(self):
+    def test_optimize_cat_with_param(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -547,7 +548,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, exir_ops.edge.aten.cat.default), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_then_slice_on_mutable_buffer(self):
+    def test_optimize_cat_then_slice_on_mutable_buffer(self) -> None:
         class CatWithPadding(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -572,7 +573,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 1)
         self.verify_nop_memory_alloc(graph_module)

-    def test_optimize_cat_with_view(self):
+    def test_optimize_cat_with_view(self) -> None:
         class CatViewFeasible(torch.nn.Module):
             def forward(self, x, y):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -599,7 +600,7 @@ def forward(self, x, y):
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_repeated_args(self):
+    def test_no_optimize_cat_with_repeated_args(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x):
                 x1 = torch.add(x, 2.4, 3.1)
@@ -623,7 +624,7 @@ def forward(self, x):
         self.assertEqual(count_node(graph_module, torch.ops.aten._cat_nop.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_no_optimize_cat_with_placeholder(self):
+    def test_no_optimize_cat_with_placeholder(self) -> None:
         class CatViewInfeasible(torch.nn.Module):
             def forward(self, x, y):
                 # Repeat will be decomposed into a cat. The cat cannot be optimized
@@ -741,7 +742,7 @@ def forward(self, x) -> torch.Tensor:
         self.assertEqual(count_node(graph_module, torch.ops.aten.cat.out), 0)
         self.verify_nop_memory_alloc(graph_module)

-    def test_view_for_unallocated_output(self):
+    def test_view_for_unallocated_output(self) -> None:
         class Model(torch.nn.Module):
             def __init__(self, padding_shape):
                 super().__init__()
@@ -764,3 +765,40 @@ def forward(self, x, y):
         )
         self.assertEqual(count_node(graph_module, memory.view), 1)
         self.verify_nop_memory_alloc(graph_module)
+
+    def test_start_alignment_constraints(self) -> None:
+        class Model(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+
+            def forward(self, x: torch.Tensor, y: torch.Tensor):
+                add_0 = torch.add(x, y)
+                add_1 = torch.add(x, add_0)
+                add_2 = torch.add(add_0, add_1)
+                add_3 = torch.add(add_1, add_2)
+                return add_3
+
+        model = Model()
+        inputs = (torch.randn(4, 17), torch.randn(4, 17))
+        for mem_algo in range(0, 2):
+            graph_module = (
+                compiler.export_to_executorch_gen_etrecord(
+                    model,
+                    inputs,
+                    opt_level=1,
+                    mem_algo=mem_algo,
+                    alloc_graph_input=False,
+                    alloc_graph_output=False,
+                    mem_alignment=37,
+                )
+                .exported_program()
+                .graph_module
+            )
+            # Assert that every memory allocation starts at a 37-byte-aligned offset
+            for spec in collect_specs_from_nodes(
+                graph_module.graph.nodes,
+                ignore_graph_input=True,
+                ignore_graph_output=True,
+            ):
+                if spec and spec.mem_offset:
+                    self.assertEqual(spec.mem_offset % 37, 0)
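A quick sanity check of the invariant this test relies on (a sketch; it assumes executorch is importable so the helper added in memory_planning.py above is available): every offset produced by get_aligned_offset is a multiple of the requested alignment, which is exactly what the modulo check verifies.

from executorch.backends.cadence.aot.memory_planning import get_aligned_offset

# Any rounded-up offset is a multiple of the alignment, so mem_offset % 37 == 0.
for raw_offset in range(500):
    assert get_aligned_offset(raw_offset, 37) % 37 == 0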

backends/vulkan/runtime/graph/ops/glsl/matmul_naive_buffer.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/addmm_naive_buffer.glsl

Lines changed: 17 additions & 6 deletions
@@ -10,6 +10,9 @@

 #define PRECISION ${PRECISION}

+$if HAS_BIAS:
+  #define HAS_BIAS
+
 #define T ${buffer_scalar_type(DTYPE)}

 ${define_required_extensions(DTYPE)}
@@ -19,13 +22,17 @@ layout(std430) buffer;
 ${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, "buffer")}
 ${layout_declare_tensor(B, "r", "t_mat2", DTYPE, "buffer")}
+$if HAS_BIAS:
+  ${layout_declare_tensor(B, "r", "t_bias", DTYPE, "buffer")}
 ${layout_declare_ubo(B, "ivec4", "out_sizes")}
 ${layout_declare_ubo(B, "ivec4", "out_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat1_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat1_strides")}
 ${layout_declare_ubo(B, "ivec4", "mat2_sizes")}
 ${layout_declare_ubo(B, "ivec4", "mat2_strides")}
 ${layout_declare_ubo(B, "int", "out_numel")}
+$if HAS_BIAS:
+  ${layout_declare_ubo(B, "float", "alpha", "float", "beta")}

 #include "indexing_utils.h"

@@ -34,25 +41,25 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
 ${layout_declare_spec_const(C, "int", "mat2_is_transposed", "0")}

 void main() {
-  const ivec4 out_bufix = ivec4(
+  const ivec4 out_tidx = ivec4(
     gl_GlobalInvocationID.x,
     gl_GlobalInvocationID.y,
     gl_GlobalInvocationID.z % out_sizes.z,
     gl_GlobalInvocationID.z / out_sizes.z);

-  if (any(greaterThanEqual(out_bufix, out_sizes))) {
+  if (any(greaterThanEqual(out_tidx, out_sizes))) {
     return;
   }

   int mat1_bufi = tidx_to_bufi(
-      ivec4(0, out_bufix.y, out_bufix.z, out_bufix.w), mat1_strides);
+      ivec4(0, out_tidx.y, out_tidx.z, out_tidx.w), mat1_strides);
   int mat2_bufi;
   if (mat2_is_transposed > 0) {
     mat2_bufi = tidx_to_bufi(
-        ivec4(0, out_bufix.x, 0, 0), mat2_strides);
+        ivec4(0, out_tidx.x, 0, 0), mat2_strides);
   } else {
     mat2_bufi = tidx_to_bufi(
-        ivec4(out_bufix.x, 0, out_bufix.z, out_bufix.w), mat2_strides);
+        ivec4(out_tidx.x, 0, out_tidx.z, out_tidx.w), mat2_strides);
   }

   int mat2_stride;
@@ -70,6 +77,10 @@ void main() {
     mat2_bufi += mat2_stride;
   }

-  const int out_bufi = tidx_to_bufi(out_bufix, out_strides);
+  const int out_bufi = tidx_to_bufi(out_tidx, out_strides);
+#ifdef HAS_BIAS
+  t_out[out_bufi] = T(alpha) * T(sum) + T(beta) * t_bias[out_tidx.x];
+#else
   t_out[out_bufi] = T(sum);
+#endif // HAS_BIAS
 }
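For reference, the computation the renamed addmm_naive_buffer shader performs when HAS_BIAS is generated, expressed with torch (a sketch of the expected semantics; alpha, beta, and the shapes are illustrative):

import torch

mat1 = torch.randn(3, 5)
mat2 = torch.randn(5, 7)
bias = torch.randn(7)  # one bias value per output column, shared across rows (t_bias is indexed by out_tidx.x)
alpha, beta = 1.0, 1.0

# out = alpha * (mat1 @ mat2) + beta * bias, i.e. torch.addmm semantics.
expected = torch.addmm(bias, mat1, mat2, beta=beta, alpha=alpha)
manual = alpha * (mat1 @ mat2) + beta * bias
assert torch.allclose(expected, manual)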
