Skip to content

Commit 3cc6b7f

Browse files
add storage_texture option to as_bind_group macro
Changes: - Add storage_texture option to as_bind_group macro - Use it to generate the bind group layout for the compute shader example
1 parent b6ead2b commit 3cc6b7f

File tree

3 files changed

+127
-21
lines changed

3 files changed

+127
-21
lines changed

crates/bevy_render/macros/src/as_bind_group.rs

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::{
1111

1212
const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
1313
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
14+
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
1415
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
1516
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
1617
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
@@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
1920
enum BindingType {
2021
Uniform,
2122
Texture,
23+
StorageTexture,
2224
Sampler,
2325
Storage,
2426
}
@@ -139,6 +141,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
139141
BindingType::Uniform
140142
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
141143
BindingType::Texture
144+
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
145+
BindingType::StorageTexture
142146
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
143147
BindingType::Sampler
144148
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@@ -262,6 +266,43 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
262266
}
263267
});
264268
}
269+
BindingType::StorageTexture => {
270+
let StorageTextureAttrs {
271+
dimension,
272+
image_format,
273+
access,
274+
visibility,
275+
} = get_storage_texture_binding_attr(nested_meta_items)?;
276+
277+
let visibility =
278+
visibility.hygienic_quote(&quote! { #render_path::render_resource });
279+
280+
let fallback_image = get_fallback_image(&render_path, dimension);
281+
282+
binding_impls.push(quote! {
283+
#render_path::render_resource::OwnedBindingResource::TextureView({
284+
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
285+
if let Some(handle) = handle {
286+
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
287+
} else {
288+
#fallback_image.texture_view.clone()
289+
}
290+
})
291+
});
292+
293+
binding_layouts.push(quote! {
294+
#render_path::render_resource::BindGroupLayoutEntry {
295+
binding: #binding_index,
296+
visibility: #visibility,
297+
ty: #render_path::render_resource::BindingType::StorageTexture {
298+
access: #render_path::render_resource::StorageTextureAccess::#access,
299+
format: #render_path::render_resource::TextureFormat::#image_format,
300+
view_dimension: #render_path::render_resource::#dimension,
301+
},
302+
count: None,
303+
}
304+
});
305+
}
265306
BindingType::Texture => {
266307
let TextureAttrs {
267308
dimension,
@@ -593,6 +634,10 @@ impl ShaderStageVisibility {
593634
fn vertex_fragment() -> Self {
594635
Self::Flags(VisibilityFlags::vertex_fragment())
595636
}
637+
638+
fn compute() -> Self {
639+
Self::Flags(VisibilityFlags::compute())
640+
}
596641
}
597642

598643
impl VisibilityFlags {
@@ -603,6 +648,13 @@ impl VisibilityFlags {
603648
..Default::default()
604649
}
605650
}
651+
652+
fn compute() -> Self {
653+
Self {
654+
compute: true,
655+
..Default::default()
656+
}
657+
}
606658
}
607659

608660
impl ShaderStageVisibility {
@@ -749,7 +801,72 @@ impl Default for TextureAttrs {
749801
}
750802
}
751803

804+
struct StorageTextureAttrs {
805+
dimension: BindingTextureDimension,
806+
// Parsing of the image_format parameter is deferred to the type checker,
807+
// which will error if the format is not member of the TextureFormat enum.
808+
image_format: proc_macro2::TokenStream,
809+
// Parsing of the access parameter is deferred to the type checker,
810+
// which will error if the format is not member of the TextureFormat enum.
811+
access: proc_macro2::TokenStream,
812+
visibility: ShaderStageVisibility,
813+
}
814+
815+
impl Default for StorageTextureAttrs {
816+
fn default() -> Self {
817+
Self {
818+
dimension: Default::default(),
819+
image_format: quote! { Rgba8Unorm },
820+
access: quote! { ReadWrite },
821+
visibility: ShaderStageVisibility::compute(),
822+
}
823+
}
824+
}
825+
826+
fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
827+
let mut storage_texture_attrs = StorageTextureAttrs::default();
828+
829+
for meta in metas {
830+
use syn::Meta::{List, NameValue};
831+
match meta {
832+
// Parse #[storage_texture(0, dimension = "...")].
833+
NameValue(m) if m.path == DIMENSION => {
834+
let value = get_lit_str(DIMENSION, &m.value)?;
835+
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
836+
}
837+
// Parse #[storage_texture(0, format = ...))].
838+
NameValue(m) if m.path == IMAGE_FORMAT => {
839+
storage_texture_attrs.image_format = m.value.into_token_stream();
840+
}
841+
// Parse #[storage_texture(0, access = ...))].
842+
NameValue(m) if m.path == ACCESS => {
843+
storage_texture_attrs.access = m.value.into_token_stream();
844+
}
845+
// Parse #[storage_texture(0, visibility(...))].
846+
List(m) if m.path == VISIBILITY => {
847+
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
848+
}
849+
NameValue(m) => {
850+
return Err(Error::new_spanned(
851+
m.path,
852+
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
853+
));
854+
}
855+
_ => {
856+
return Err(Error::new_spanned(
857+
meta,
858+
"Not a name value pair: `foo = \"...\"`",
859+
));
860+
}
861+
}
862+
}
863+
864+
Ok(storage_texture_attrs)
865+
}
866+
752867
const DIMENSION: Symbol = Symbol("dimension");
868+
const IMAGE_FORMAT: Symbol = Symbol("image_format");
869+
const ACCESS: Symbol = Symbol("access");
753870
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
754871
const FILTERABLE: Symbol = Symbol("filterable");
755872
const MULTISAMPLED: Symbol = Symbol("multisampled");

crates/bevy_render/macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {
5151

5252
#[proc_macro_derive(
5353
AsBindGroup,
54-
attributes(uniform, texture, sampler, bind_group_data, storage)
54+
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
5555
)]
5656
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
5757
let input = parse_macro_input!(input as DeriveInput);

examples/shader/compute_shader_game_of_life.rs

Lines changed: 9 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
6363
});
6464
commands.spawn(Camera2dBundle::default());
6565

66-
commands.insert_resource(GameOfLifeImage(image));
66+
commands.insert_resource(GameOfLifeImage { texture: image });
6767
}
6868

6969
pub struct GameOfLifeComputePlugin;
@@ -93,8 +93,11 @@ impl Plugin for GameOfLifeComputePlugin {
9393
}
9494
}
9595

96-
#[derive(Resource, Clone, Deref, ExtractResource)]
97-
struct GameOfLifeImage(Handle<Image>);
96+
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
97+
struct GameOfLifeImage {
98+
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
99+
texture: Handle<Image>,
100+
}
98101

99102
#[derive(Resource)]
100103
struct GameOfLifeImageBindGroup(BindGroup);
@@ -106,7 +109,7 @@ fn prepare_bind_group(
106109
game_of_life_image: Res<GameOfLifeImage>,
107110
render_device: Res<RenderDevice>,
108111
) {
109-
let view = gpu_images.get(&game_of_life_image.0).unwrap();
112+
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
110113
let bind_group = render_device.create_bind_group(&BindGroupDescriptor {
111114
label: None,
112115
layout: &pipeline.texture_bind_group_layout,
@@ -127,22 +130,8 @@ pub struct GameOfLifePipeline {
127130

128131
impl FromWorld for GameOfLifePipeline {
129132
fn from_world(world: &mut World) -> Self {
130-
let texture_bind_group_layout =
131-
world
132-
.resource::<RenderDevice>()
133-
.create_bind_group_layout(&BindGroupLayoutDescriptor {
134-
label: None,
135-
entries: &[BindGroupLayoutEntry {
136-
binding: 0,
137-
visibility: ShaderStages::COMPUTE,
138-
ty: BindingType::StorageTexture {
139-
access: StorageTextureAccess::ReadWrite,
140-
format: TextureFormat::Rgba8Unorm,
141-
view_dimension: TextureViewDimension::D2,
142-
},
143-
count: None,
144-
}],
145-
});
133+
let render_device = world.resource::<RenderDevice>();
134+
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
146135
let shader = world
147136
.resource::<AssetServer>()
148137
.load("shaders/game_of_life.wgsl");

0 commit comments

Comments
 (0)