Skip to content

Commit 81e510b

Browse files
add storage_texture option to as_bind_group macro
Changes: - Add storage_texture option to as_bind_group macro - Use it to generate the bind group layout for the compute shader example
1 parent b6ead2b commit 81e510b

File tree

4 files changed

+157
-31
lines changed

4 files changed

+157
-31
lines changed

crates/bevy_render/macros/src/as_bind_group.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::{
1111

1212
const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
1313
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
14+
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
1415
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
1516
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
1617
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
@@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
1920
enum BindingType {
2021
Uniform,
2122
Texture,
23+
StorageTexture,
2224
Sampler,
2325
Storage,
2426
}
@@ -139,6 +141,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
139141
BindingType::Uniform
140142
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
141143
BindingType::Texture
144+
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
145+
BindingType::StorageTexture
142146
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
143147
BindingType::Sampler
144148
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@@ -262,6 +266,43 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
262266
}
263267
});
264268
}
269+
BindingType::StorageTexture => {
270+
let StorageTextureAttrs {
271+
dimension,
272+
image_format,
273+
access,
274+
visibility,
275+
} = get_storage_texture_binding_attr(nested_meta_items)?;
276+
277+
let visibility =
278+
visibility.hygienic_quote(&quote! { #render_path::render_resource });
279+
280+
let fallback_image = get_fallback_image(&render_path, dimension);
281+
282+
binding_impls.push(quote! {
283+
#render_path::render_resource::OwnedBindingResource::TextureView({
284+
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
285+
if let Some(handle) = handle {
286+
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
287+
} else {
288+
#fallback_image.texture_view.clone()
289+
}
290+
})
291+
});
292+
293+
binding_layouts.push(quote! {
294+
#render_path::render_resource::BindGroupLayoutEntry {
295+
binding: #binding_index,
296+
visibility: #visibility,
297+
ty: #render_path::render_resource::BindingType::StorageTexture {
298+
access: #render_path::render_resource::StorageTextureAccess::#access,
299+
format: #render_path::render_resource::TextureFormat::#image_format,
300+
view_dimension: #render_path::render_resource::#dimension,
301+
},
302+
count: None,
303+
}
304+
});
305+
}
265306
BindingType::Texture => {
266307
let TextureAttrs {
267308
dimension,
@@ -593,6 +634,10 @@ impl ShaderStageVisibility {
593634
fn vertex_fragment() -> Self {
594635
Self::Flags(VisibilityFlags::vertex_fragment())
595636
}
637+
638+
fn compute() -> Self {
639+
Self::Flags(VisibilityFlags::compute())
640+
}
596641
}
597642

598643
impl VisibilityFlags {
@@ -603,6 +648,13 @@ impl VisibilityFlags {
603648
..Default::default()
604649
}
605650
}
651+
652+
fn compute() -> Self {
653+
Self {
654+
compute: true,
655+
..Default::default()
656+
}
657+
}
606658
}
607659

608660
impl ShaderStageVisibility {
@@ -749,7 +801,72 @@ impl Default for TextureAttrs {
749801
}
750802
}
751803

804+
struct StorageTextureAttrs {
805+
dimension: BindingTextureDimension,
806+
// Parsing of the image_format parameter is deferred to the type checker,
807+
// which will error if the format is not member of the TextureFormat enum.
808+
image_format: proc_macro2::TokenStream,
809+
// Parsing of the access parameter is deferred to the type checker,
810+
// which will error if the access is not member of the StorageTextureAccess enum.
811+
access: proc_macro2::TokenStream,
812+
visibility: ShaderStageVisibility,
813+
}
814+
815+
impl Default for StorageTextureAttrs {
816+
fn default() -> Self {
817+
Self {
818+
dimension: Default::default(),
819+
image_format: quote! { Rgba8Unorm },
820+
access: quote! { ReadWrite },
821+
visibility: ShaderStageVisibility::compute(),
822+
}
823+
}
824+
}
825+
826+
fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
827+
let mut storage_texture_attrs = StorageTextureAttrs::default();
828+
829+
for meta in metas {
830+
use syn::Meta::{List, NameValue};
831+
match meta {
832+
// Parse #[storage_texture(0, dimension = "...")].
833+
NameValue(m) if m.path == DIMENSION => {
834+
let value = get_lit_str(DIMENSION, &m.value)?;
835+
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
836+
}
837+
// Parse #[storage_texture(0, format = ...))].
838+
NameValue(m) if m.path == IMAGE_FORMAT => {
839+
storage_texture_attrs.image_format = m.value.into_token_stream();
840+
}
841+
// Parse #[storage_texture(0, access = ...))].
842+
NameValue(m) if m.path == ACCESS => {
843+
storage_texture_attrs.access = m.value.into_token_stream();
844+
}
845+
// Parse #[storage_texture(0, visibility(...))].
846+
List(m) if m.path == VISIBILITY => {
847+
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
848+
}
849+
NameValue(m) => {
850+
return Err(Error::new_spanned(
851+
m.path,
852+
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
853+
));
854+
}
855+
_ => {
856+
return Err(Error::new_spanned(
857+
meta,
858+
"Not a name value pair: `foo = \"...\"`",
859+
));
860+
}
861+
}
862+
}
863+
864+
Ok(storage_texture_attrs)
865+
}
866+
752867
const DIMENSION: Symbol = Symbol("dimension");
868+
const IMAGE_FORMAT: Symbol = Symbol("image_format");
869+
const ACCESS: Symbol = Symbol("access");
753870
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
754871
const FILTERABLE: Symbol = Symbol("filterable");
755872
const MULTISAMPLED: Symbol = Symbol("multisampled");

crates/bevy_render/macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {
5151

5252
#[proc_macro_derive(
5353
AsBindGroup,
54-
attributes(uniform, texture, sampler, bind_group_data, storage)
54+
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
5555
)]
5656
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
5757
let input = parse_macro_input!(input as DeriveInput);

crates/bevy_render/src/render_resource/bind_group.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,8 @@ impl Deref for BindGroup {
8686
/// values: Vec<f32>,
8787
/// #[storage(4, read_only, buffer)]
8888
/// buffer: Buffer,
89+
/// #[storage_texture(5)]
90+
/// storage_texture: Handle<Image>,
8991
/// }
9092
/// ```
9193
///
@@ -96,6 +98,7 @@ impl Deref for BindGroup {
9698
/// @group(1) @binding(1) var color_texture: texture_2d<f32>;
9799
/// @group(1) @binding(2) var color_sampler: sampler;
98100
/// @group(1) @binding(3) var<storage> values: array<f32>;
101+
/// @group(1) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
99102
/// ```
100103
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
101104
/// are generally bound to group 1.
@@ -122,6 +125,19 @@ impl Deref for BindGroup {
122125
/// | `multisampled` = ... | `true`, `false` | `false` |
123126
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
124127
///
128+
/// * `storage_texture(BINDING_INDEX, arguments)`
129+
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
130+
/// GPU resource, which will be bound as a storage_texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
131+
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
132+
/// [`None`], the [`FallbackImage`] resource will be used instead.
133+
///
134+
/// | Arguments | Values | Default |
135+
/// |------------------------|-------------------------------------------------------------------------|----------------------|
136+
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
137+
/// | `image_format` = ... | any member of `wgpu::TextureFormat` | `RgbaUniform` |
138+
/// | `access` = ... | any member of `wgpu::StorageTextureAccess` | `ReadWrite` |
139+
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
140+
///
125141
/// * `sampler(BINDING_INDEX, arguments)`
126142
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`](crate::render_resource::Sampler) GPU
127143
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,

examples/shader/compute_shader_game_of_life.rs

Lines changed: 23 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use bevy::{
1111
render_graph::{self, RenderGraph},
1212
render_resource::*,
1313
renderer::{RenderContext, RenderDevice},
14+
texture::FallbackImage,
1415
Render, RenderApp, RenderSet,
1516
},
1617
window::WindowPlugin,
@@ -63,7 +64,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
6364
});
6465
commands.spawn(Camera2dBundle::default());
6566

66-
commands.insert_resource(GameOfLifeImage(image));
67+
commands.insert_resource(GameOfLifeImage { texture: image });
6768
}
6869

6970
pub struct GameOfLifeComputePlugin;
@@ -93,28 +94,34 @@ impl Plugin for GameOfLifeComputePlugin {
9394
}
9495
}
9596

96-
#[derive(Resource, Clone, Deref, ExtractResource)]
97-
struct GameOfLifeImage(Handle<Image>);
97+
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
98+
struct GameOfLifeImage {
99+
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
100+
texture: Handle<Image>,
101+
}
98102

99103
#[derive(Resource)]
100-
struct GameOfLifeImageBindGroup(BindGroup);
104+
struct GameOfLifeImageBindGroup(PreparedBindGroup<()>);
101105

102106
fn prepare_bind_group(
103107
mut commands: Commands,
104108
pipeline: Res<GameOfLifePipeline>,
105109
gpu_images: Res<RenderAssets<Image>>,
106110
game_of_life_image: Res<GameOfLifeImage>,
107111
render_device: Res<RenderDevice>,
112+
fallback_image: Res<FallbackImage>,
108113
) {
109-
let view = gpu_images.get(&game_of_life_image.0).unwrap();
110-
let bind_group = render_device.create_bind_group(&BindGroupDescriptor {
111-
label: None,
112-
layout: &pipeline.texture_bind_group_layout,
113-
entries: &[BindGroupEntry {
114-
binding: 0,
115-
resource: BindingResource::TextureView(&view.texture_view),
116-
}],
117-
});
114+
// When `AsBindGroup` is derived, `as_bind_group` will never return
115+
// an error, so we can safely unwrap.
116+
let bind_group = game_of_life_image
117+
.as_bind_group(
118+
&pipeline.texture_bind_group_layout,
119+
&render_device,
120+
&gpu_images,
121+
&fallback_image,
122+
)
123+
.ok()
124+
.unwrap();
118125
commands.insert_resource(GameOfLifeImageBindGroup(bind_group));
119126
}
120127

@@ -127,22 +134,8 @@ pub struct GameOfLifePipeline {
127134

128135
impl FromWorld for GameOfLifePipeline {
129136
fn from_world(world: &mut World) -> Self {
130-
let texture_bind_group_layout =
131-
world
132-
.resource::<RenderDevice>()
133-
.create_bind_group_layout(&BindGroupLayoutDescriptor {
134-
label: None,
135-
entries: &[BindGroupLayoutEntry {
136-
binding: 0,
137-
visibility: ShaderStages::COMPUTE,
138-
ty: BindingType::StorageTexture {
139-
access: StorageTextureAccess::ReadWrite,
140-
format: TextureFormat::Rgba8Unorm,
141-
view_dimension: TextureViewDimension::D2,
142-
},
143-
count: None,
144-
}],
145-
});
137+
let render_device = world.resource::<RenderDevice>();
138+
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
146139
let shader = world
147140
.resource::<AssetServer>()
148141
.load("shaders/game_of_life.wgsl");
@@ -229,7 +222,7 @@ impl render_graph::Node for GameOfLifeNode {
229222
.command_encoder()
230223
.begin_compute_pass(&ComputePassDescriptor::default());
231224

232-
pass.set_bind_group(0, texture_bind_group, &[]);
225+
pass.set_bind_group(0, &texture_bind_group.bind_group, &[]);
233226

234227
// select the pipeline based on the current state
235228
match self.state {

0 commit comments

Comments
 (0)