Skip to content

Commit 8194a6a

Browse files
committed
Add HLSL generator
1 parent 719dc53 commit 8194a6a

File tree

3 files changed

+1703
-0
lines changed

3 files changed

+1703
-0
lines changed

tools/hlsl_generator/gen.py

Lines changed: 297 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,297 @@
1+
import json
2+
import io
3+
import os
4+
import re
5+
from enum import Enum
6+
from argparse import ArgumentParser
7+
from typing import NamedTuple
8+
from typing import Optional
9+
10+
head = """// Copyright (C) 2023 - DevSH Graphics Programming Sp. z O.O.
11+
// This file is part of the "Nabla Engine".
12+
// For conditions of distribution and use, see copyright notice in nabla.h
13+
#ifndef _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
14+
#define _NBL_BUILTIN_HLSL_SPIRV_INTRINSICS_CORE_INCLUDED_
15+
16+
#ifdef __HLSL_VERSION
17+
#include "spirv/unified1/spirv.hpp"
18+
#include "spirv/unified1/GLSL.std.450.h"
19+
#endif
20+
21+
#include "nbl/builtin/hlsl/type_traits.hlsl"
22+
23+
namespace nbl
24+
{
25+
namespace hlsl
26+
{
27+
#ifdef __HLSL_VERSION
28+
namespace spirv
29+
{
30+
31+
//! General Decls
32+
template<uint32_t StorageClass, typename T>
33+
using pointer_t = vk::SpirvOpaqueType<spv::OpTypePointer, vk::Literal< vk::integral_constant<uint32_t, StorageClass> >, T>;
34+
35+
// The holy operation that makes addrof possible
36+
template<uint32_t StorageClass, typename T>
37+
[[vk::ext_instruction(spv::OpCopyObject)]]
38+
pointer_t<StorageClass, T> copyObject([[vk::ext_reference]] T value);
39+
40+
//! Std 450 Extended set operations
41+
template<typename SquareMatrix>
42+
[[vk::ext_instruction(GLSLstd450MatrixInverse)]]
43+
SquareMatrix matrixInverse(NBL_CONST_REF_ARG(SquareMatrix) mat);
44+
45+
// Add specializations if you need to emit a `ext_capability` (this means that the instruction needs to forward through an `impl::` struct and so on)
46+
template<typename T, typename U>
47+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
48+
[[vk::ext_instruction(spv::OpBitcast)]]
49+
enable_if_t<is_spirv_type_v<T> && is_spirv_type_v<U>, T> bitcast(U);
50+
51+
template<typename T>
52+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
53+
[[vk::ext_instruction(spv::OpBitcast)]]
54+
uint64_t bitcast(pointer_t<spv::StorageClassPhysicalStorageBuffer,T>);
55+
56+
template<typename T>
57+
[[vk::ext_capability(spv::CapabilityPhysicalStorageBufferAddresses)]]
58+
[[vk::ext_instruction(spv::OpBitcast)]]
59+
pointer_t<spv::StorageClassPhysicalStorageBuffer,T> bitcast(uint64_t);
60+
61+
template<class T, class U>
62+
[[vk::ext_instruction(spv::OpBitcast)]]
63+
T bitcast(U);
64+
"""
65+
66+
foot = """}
67+
68+
#endif
69+
}
70+
}
71+
72+
#endif
73+
"""
74+
75+
def gen(grammer_path, metadata_path, output_path):
76+
grammer_raw = open(grammer_path, "r").read()
77+
grammer = json.loads(grammer_raw)
78+
del grammer_raw
79+
80+
metadata_raw = open(metadata_path, "r").read()
81+
metadata = json.loads(metadata_raw)
82+
del metadata_raw
83+
84+
output = open(output_path, "w", buffering=1024**2)
85+
86+
builtins = [x for x in grammer["operand_kinds"] if x["kind"] == "BuiltIn"][0]["enumerants"]
87+
execution_modes = [x for x in grammer["operand_kinds"] if x["kind"] == "ExecutionMode"][0]["enumerants"]
88+
group_operations = [x for x in grammer["operand_kinds"] if x["kind"] == "GroupOperation"][0]["enumerants"]
89+
90+
with output as writer:
91+
writer.write(head)
92+
93+
writer.write("\n//! Builtins\nnamespace builtin\n{")
94+
for name in metadata["builtins"].keys():
95+
# Validate
96+
builtin_exist = False
97+
for b in builtins:
98+
if b["enumerant"] == name: builtin_exist = True
99+
100+
if (builtin_exist):
101+
bm = metadata["builtins"][name]
102+
is_mutable = "const" in bm.keys() and bm["mutable"]
103+
writer.write("[[vk::ext_builtin_input(spv::BuiltIn" + name + ")]]\n")
104+
writer.write("static " + ("" if is_mutable else "const ") + bm["type"] + " " + name + ";\n")
105+
else:
106+
raise Exception("Invalid builtin " + name)
107+
writer.write("}\n")
108+
109+
writer.write("\n//! Execution Modes\nnamespace execution_mode\n{")
110+
for em in execution_modes:
111+
name = em["enumerant"]
112+
name_l = name[0].lower() + name[1:]
113+
writer.write("\n\tvoid " + name_l + "()\n\t{\n\t\tvk::ext_execution_mode(spv::ExecutionMode" + name + ");\n\t}\n")
114+
writer.write("}\n")
115+
116+
writer.write("\n//! Group Operations\nnamespace group_operation\n{\n")
117+
for go in group_operations:
118+
name = go["enumerant"]
119+
value = go["value"]
120+
writer.write("\tstatic const uint32_t " + name + " = " + str(value) + ";\n")
121+
writer.write("}\n")
122+
123+
writer.write("\n//! Instructions\n")
124+
for instruction in grammer["instructions"]:
125+
match instruction["class"]:
126+
case "Atomic":
127+
processInst(writer, instruction, InstOptions())
128+
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
129+
case "Memory":
130+
processInst(writer, instruction, InstOptions(shape=Shape.PTR_TEMPLATE))
131+
processInst(writer, instruction, InstOptions(shape=Shape.PSB_RT))
132+
case "Barrier" | "Bit":
133+
processInst(writer, instruction, InstOptions())
134+
case "Reserved":
135+
match instruction["opname"]:
136+
case "OpBeginInvocationInterlockEXT" | "OpEndInvocationInterlockEXT":
137+
processInst(writer, instruction, InstOptions())
138+
case "Non-Uniform":
139+
match instruction["opname"]:
140+
case "OpGroupNonUniformElect" | "OpGroupNonUniformAll" | "OpGroupNonUniformAny" | "OpGroupNonUniformAllEqual":
141+
processInst(writer, instruction, InstOptions(result_ty="bool"))
142+
case "OpGroupNonUniformBallot":
143+
processInst(writer, instruction, InstOptions(result_ty="uint32_t4",op_ty="bool"))
144+
case "OpGroupNonUniformInverseBallot" | "OpGroupNonUniformBallotBitExtract":
145+
processInst(writer, instruction, InstOptions(result_ty="bool",op_ty="uint32_t4"))
146+
case "OpGroupNonUniformBallotBitCount" | "OpGroupNonUniformBallotFindLSB" | "OpGroupNonUniformBallotFindMSB":
147+
processInst(writer, instruction, InstOptions(result_ty="uint32_t",op_ty="uint32_t4"))
148+
case _: processInst(writer, instruction, InstOptions())
149+
case _: continue # TODO
150+
151+
writer.write(foot)
152+
153+
class Shape(Enum):
154+
DEFAULT = 0,
155+
PTR_TEMPLATE = 1, # TODO: this is a DXC Workaround
156+
PSB_RT = 2, # PhysicalStorageBuffer Result Type
157+
158+
class InstOptions(NamedTuple):
159+
shape: Shape = Shape.DEFAULT
160+
result_ty: Optional[str] = None
161+
op_ty: Optional[str] = None
162+
163+
def processInst(writer: io.TextIOWrapper, instruction, options: InstOptions):
164+
templates = []
165+
caps = []
166+
conds = []
167+
op_name = instruction["opname"]
168+
fn_name = op_name[2].lower() + op_name[3:]
169+
result_types = []
170+
171+
if "capabilities" in instruction and len(instruction["capabilities"]) > 0:
172+
for cap in instruction["capabilities"]:
173+
if cap == "Shader" or cap == "Kernel": continue
174+
caps.append(cap)
175+
176+
if options.shape == Shape.PTR_TEMPLATE:
177+
templates.append("typename P")
178+
conds.append("is_spirv_type_v<P>")
179+
180+
# split upper case words
181+
matches = [(m.group(1), m.span(1)) for m in re.finditer(r'([A-Z])[A-Z][a-z]', fn_name)]
182+
183+
for m in matches:
184+
match m[0]:
185+
case "I":
186+
conds.append("(is_signed_v<T> || is_unsigned_v<T>)")
187+
break
188+
case "U":
189+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
190+
result_types = ["uint32_t", "uint64_t"]
191+
break
192+
case "S":
193+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
194+
result_types = ["int32_t", "int64_t"]
195+
break
196+
case "F":
197+
fn_name = fn_name[0:m[1][0]] + fn_name[m[1][1]:]
198+
result_types = ["float"]
199+
break
200+
201+
if "operands" in instruction:
202+
operands = instruction["operands"]
203+
if operands[0]["kind"] == "IdResultType":
204+
operands = operands[2:]
205+
if len(result_types) == 0:
206+
if options.result_ty == None:
207+
result_types = ["T"]
208+
else:
209+
result_types = [options.result_ty]
210+
else:
211+
assert len(result_types) == 0
212+
result_types = ["void"]
213+
214+
for rt in result_types:
215+
op_ty = "T"
216+
if options.op_ty != None:
217+
op_ty = options.op_ty
218+
elif rt != "void":
219+
op_ty = rt
220+
221+
if (not "typename T" in templates) and (rt == "T"):
222+
templates = ["typename T"] + templates
223+
224+
args = []
225+
for operand in operands:
226+
operand_name = operand["name"].strip("'") if "name" in operand else None
227+
operand_name = operand_name[0].lower() + operand_name[1:] if (operand_name != None) else ""
228+
match operand["kind"]:
229+
case "IdRef":
230+
match operand["name"]:
231+
case "'Pointer'":
232+
if options.shape == Shape.PTR_TEMPLATE:
233+
args.append("P " + operand_name)
234+
elif options.shape == Shape.PSB_RT:
235+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
236+
templates = ["typename T"] + templates
237+
args.append("pointer_t<spv::StorageClassPhysicalStorageBuffer, " + op_ty + "> " + operand_name)
238+
else:
239+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
240+
templates = ["typename T"] + templates
241+
args.append("[[vk::ext_reference]] " + op_ty + " " + operand_name)
242+
case "'Value'" | "'Object'" | "'Comparator'" | "'Base'" | "'Insert'":
243+
if (not "typename T" in templates) and (rt == "T" or op_ty == "T"):
244+
templates = ["typename T"] + templates
245+
args.append(op_ty + " " + operand_name)
246+
case "'Offset'" | "'Count'" | "'Id'" | "'Index'" | "'Mask'" | "'Delta'":
247+
args.append("uint32_t " + operand_name)
248+
case "'Predicate'": args.append("bool " + operand_name)
249+
case "'ClusterSize'":
250+
if "quantifier" in operand and operand["quantifier"] == "?": continue # TODO: overload
251+
else: return # TODO
252+
case _: return # TODO
253+
case "IdScope": args.append("uint32_t " + operand_name.lower() + "Scope")
254+
case "IdMemorySemantics": args.append(" uint32_t " + operand_name)
255+
case "GroupOperation": args.append("[[vk::ext_literal]] uint32_t " + operand_name)
256+
case "MemoryAccess":
257+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess"])
258+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t memoryAccess, [[vk::ext_literal]] uint32_t memoryAccessParam"])
259+
writeInst(writer, templates + ["uint32_t alignment"], caps, op_name, fn_name, conds, rt, args + ["[[vk::ext_literal]] uint32_t __aligned = /*Aligned*/0x00000002", "[[vk::ext_literal]] uint32_t __alignment = alignment"])
260+
case _: return # TODO
261+
262+
writeInst(writer, templates, caps, op_name, fn_name, conds, rt, args)
263+
264+
265+
def writeInst(writer: io.TextIOWrapper, templates, caps, op_name, fn_name, conds, result_type, args):
266+
if len(caps) > 0:
267+
for cap in caps:
268+
final_fn_name = fn_name
269+
if (len(caps) > 1): final_fn_name = fn_name + "_" + cap
270+
writeInstInner(writer, templates, cap, op_name, final_fn_name, conds, result_type, args)
271+
else:
272+
writeInstInner(writer, templates, None, op_name, fn_name, conds, result_type, args)
273+
274+
def writeInstInner(writer: io.TextIOWrapper, templates, cap, op_name, fn_name, conds, result_type, args):
275+
if len(templates) > 0:
276+
writer.write("template<" + ", ".join(templates) + ">\n")
277+
if (cap != None):
278+
writer.write("[[vk::ext_capability(spv::Capability" + cap + ")]]\n")
279+
writer.write("[[vk::ext_instruction(spv::" + op_name + ")]]\n")
280+
if len(conds) > 0:
281+
writer.write("enable_if_t<" + " && ".join(conds) + ", " + result_type + ">")
282+
else:
283+
writer.write(result_type)
284+
writer.write(" " + fn_name + "(" + ", ".join(args) + ");\n\n")
285+
286+
287+
if __name__ == "__main__":
288+
script_dir_path = os.path.abspath(os.path.dirname(__file__))
289+
290+
parser = ArgumentParser(description="Generate HLSL from SPIR-V instructions")
291+
parser.add_argument("output", type=str, help="HLSL output file")
292+
parser.add_argument("--grammer", required=False, type=str, help="Input SPIR-V grammer JSON file", default=os.path.join(script_dir_path, "../../include/spirv/unified1/spirv.core.grammar.json"))
293+
parser.add_argument("--metadata", required=False, type=str, help="Input SPIR-V Instructions/BuiltIns type mapping/attributes/etc", default=os.path.join(script_dir_path, "metadata.json"))
294+
args = parser.parse_args()
295+
296+
gen(args.grammer, args.metadata, args.output)
297+

tools/hlsl_generator/metadata.json

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
{
2+
"builtins": {
3+
"HelperInvocation": {
4+
"type": "bool",
5+
"mutable": true
6+
},
7+
"Position": {
8+
"type": "float32_t4"
9+
},
10+
"VertexIndex": {
11+
"type": "uint32_t",
12+
"mutable": true
13+
},
14+
"InstanceIndex": {
15+
"type": "uint32_t",
16+
"mutable": true
17+
},
18+
"NumWorkgroups": {
19+
"type": "uint32_t3",
20+
"mutable": true
21+
},
22+
"WorkgroupId": {
23+
"type": "uint32_t3",
24+
"mutable": true
25+
},
26+
"LocalInvocationId": {
27+
"type": "uint32_t3",
28+
"mutable": true
29+
},
30+
"GlobalInvocationId": {
31+
"type": "uint32_t3",
32+
"mutable": true
33+
},
34+
"LocalInvocationIndex": {
35+
"type": "uint32_t",
36+
"mutable": true
37+
},
38+
"SubgroupEqMask": {
39+
"type": "uint32_t4"
40+
},
41+
"SubgroupGeMask": {
42+
"type": "uint32_t4"
43+
},
44+
"SubgroupGtMask": {
45+
"type": "uint32_t4"
46+
},
47+
"SubgroupLeMask": {
48+
"type": "uint32_t4"
49+
},
50+
"SubgroupLtMask": {
51+
"type": "uint32_t4"
52+
},
53+
"SubgroupSize": {
54+
"type": "uint32_t"
55+
},
56+
"NumSubgroups": {
57+
"type": "uint32_t"
58+
},
59+
"SubgroupId": {
60+
"type": "uint32_t"
61+
},
62+
"SubgroupLocalInvocationId": {
63+
"type": "uint32_t"
64+
}
65+
}
66+
}

0 commit comments

Comments
 (0)