Skip to content

Commit ffbc453

Browse files
add storage_texture option to as_bind_group macro
Changes: - Add storage_texture option to as_bind_group macro - Use it to generate the bind group layout for the compute shader example
1 parent 746361b commit ffbc453

File tree

4 files changed

+149
-15
lines changed

4 files changed

+149
-15
lines changed

crates/bevy_render/macros/src/as_bind_group.rs

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use syn::{
1111

1212
const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
1313
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
14+
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
1415
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
1516
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
1617
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
@@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
1920
enum BindingType {
2021
Uniform,
2122
Texture,
23+
StorageTexture,
2224
Sampler,
2325
Storage,
2426
}
@@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
133135
BindingType::Uniform
134136
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
135137
BindingType::Texture
138+
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
139+
BindingType::StorageTexture
136140
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
137141
BindingType::Sampler
138142
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
255259
}
256260
});
257261
}
262+
BindingType::StorageTexture => {
263+
let StorageTextureAttrs {
264+
dimension,
265+
image_format,
266+
access,
267+
visibility,
268+
} = get_storage_texture_binding_attr(nested_meta_items)?;
269+
270+
let visibility =
271+
visibility.hygienic_quote(&quote! { #render_path::render_resource });
272+
273+
let fallback_image = get_fallback_image(&render_path, dimension);
274+
275+
binding_impls.push(quote! {
276+
( #binding_index,
277+
#render_path::render_resource::OwnedBindingResource::TextureView({
278+
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
279+
if let Some(handle) = handle {
280+
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
281+
} else {
282+
#fallback_image.texture_view.clone()
283+
}
284+
})
285+
)
286+
});
287+
288+
binding_layouts.push(quote! {
289+
#render_path::render_resource::BindGroupLayoutEntry {
290+
binding: #binding_index,
291+
visibility: #visibility,
292+
ty: #render_path::render_resource::BindingType::StorageTexture {
293+
access: #render_path::render_resource::StorageTextureAccess::#access,
294+
format: #render_path::render_resource::TextureFormat::#image_format,
295+
view_dimension: #render_path::render_resource::#dimension,
296+
},
297+
count: None,
298+
}
299+
});
300+
}
258301
BindingType::Texture => {
259302
let TextureAttrs {
260303
dimension,
@@ -585,6 +628,10 @@ impl ShaderStageVisibility {
585628
fn vertex_fragment() -> Self {
586629
Self::Flags(VisibilityFlags::vertex_fragment())
587630
}
631+
632+
fn compute() -> Self {
633+
Self::Flags(VisibilityFlags::compute())
634+
}
588635
}
589636

590637
impl VisibilityFlags {
@@ -595,6 +642,13 @@ impl VisibilityFlags {
595642
..Default::default()
596643
}
597644
}
645+
646+
fn compute() -> Self {
647+
Self {
648+
compute: true,
649+
..Default::default()
650+
}
651+
}
598652
}
599653

600654
impl ShaderStageVisibility {
@@ -741,7 +795,72 @@ impl Default for TextureAttrs {
741795
}
742796
}
743797

798+
struct StorageTextureAttrs {
799+
dimension: BindingTextureDimension,
800+
// Parsing of the image_format parameter is deferred to the type checker,
801+
// which will error if the format is not member of the TextureFormat enum.
802+
image_format: proc_macro2::TokenStream,
803+
// Parsing of the access parameter is deferred to the type checker,
804+
// which will error if the access is not member of the StorageTextureAccess enum.
805+
access: proc_macro2::TokenStream,
806+
visibility: ShaderStageVisibility,
807+
}
808+
809+
impl Default for StorageTextureAttrs {
810+
fn default() -> Self {
811+
Self {
812+
dimension: Default::default(),
813+
image_format: quote! { Rgba8Unorm },
814+
access: quote! { ReadWrite },
815+
visibility: ShaderStageVisibility::compute(),
816+
}
817+
}
818+
}
819+
820+
fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
821+
let mut storage_texture_attrs = StorageTextureAttrs::default();
822+
823+
for meta in metas {
824+
use syn::Meta::{List, NameValue};
825+
match meta {
826+
// Parse #[storage_texture(0, dimension = "...")].
827+
NameValue(m) if m.path == DIMENSION => {
828+
let value = get_lit_str(DIMENSION, &m.value)?;
829+
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
830+
}
831+
// Parse #[storage_texture(0, format = ...))].
832+
NameValue(m) if m.path == IMAGE_FORMAT => {
833+
storage_texture_attrs.image_format = m.value.into_token_stream();
834+
}
835+
// Parse #[storage_texture(0, access = ...))].
836+
NameValue(m) if m.path == ACCESS => {
837+
storage_texture_attrs.access = m.value.into_token_stream();
838+
}
839+
// Parse #[storage_texture(0, visibility(...))].
840+
List(m) if m.path == VISIBILITY => {
841+
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
842+
}
843+
NameValue(m) => {
844+
return Err(Error::new_spanned(
845+
m.path,
846+
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
847+
));
848+
}
849+
_ => {
850+
return Err(Error::new_spanned(
851+
meta,
852+
"Not a name value pair: `foo = \"...\"`",
853+
));
854+
}
855+
}
856+
}
857+
858+
Ok(storage_texture_attrs)
859+
}
860+
744861
const DIMENSION: Symbol = Symbol("dimension");
862+
const IMAGE_FORMAT: Symbol = Symbol("image_format");
863+
const ACCESS: Symbol = Symbol("access");
745864
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
746865
const FILTERABLE: Symbol = Symbol("filterable");
747866
const MULTISAMPLED: Symbol = Symbol("multisampled");

crates/bevy_render/macros/src/lib.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {
5151

5252
#[proc_macro_derive(
5353
AsBindGroup,
54-
attributes(uniform, texture, sampler, bind_group_data, storage)
54+
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
5555
)]
5656
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
5757
let input = parse_macro_input!(input as DeriveInput);

crates/bevy_render/src/render_resource/bind_group.rs

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,8 @@ impl Deref for BindGroup {
8787
/// values: Vec<f32>,
8888
/// #[storage(4, read_only, buffer)]
8989
/// buffer: Buffer,
90+
/// #[storage_texture(5)]
91+
/// storage_texture: Handle<Image>,
9092
/// }
9193
/// ```
9294
///
@@ -97,6 +99,7 @@ impl Deref for BindGroup {
9799
/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
98100
/// @group(2) @binding(2) var color_sampler: sampler;
99101
/// @group(2) @binding(3) var<storage> values: array<f32>;
102+
/// @group(2) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
100103
/// ```
101104
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
102105
/// are generally bound to group 2.
@@ -123,6 +126,19 @@ impl Deref for BindGroup {
123126
/// | `multisampled` = ... | `true`, `false` | `false` |
124127
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
125128
///
129+
/// * `storage_texture(BINDING_INDEX, arguments)`
130+
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
131+
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
132+
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
133+
/// [`None`], the [`FallbackImage`] resource will be used instead.
134+
///
135+
/// | Arguments | Values | Default |
136+
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
137+
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
138+
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
139+
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
140+
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
141+
///
126142
/// * `sampler(BINDING_INDEX, arguments)`
127143
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU
128144
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,

examples/shader/compute_shader_game_of_life.rs

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,16 @@
66
use bevy::{
77
prelude::*,
88
render::{
9-
extract_resource::{ExtractResource, ExtractResourcePlugin},
9+
extract_resource::ExtractResourcePlugin,
1010
render_asset::RenderAssets,
1111
render_graph::{self, RenderGraph},
12-
render_resource::{binding_types::texture_storage_2d, *},
12+
render_resource::*,
1313
renderer::{RenderContext, RenderDevice},
1414
Render, RenderApp, RenderSet,
1515
},
1616
window::WindowPlugin,
1717
};
18+
use bevy_internal::render::extract_resource::ExtractResource;
1819
use std::borrow::Cow;
1920

2021
const SIZE: (u32, u32) = (1280, 720);
@@ -63,7 +64,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
6364
});
6465
commands.spawn(Camera2dBundle::default());
6566

66-
commands.insert_resource(GameOfLifeImage(image));
67+
commands.insert_resource(GameOfLifeImage { texture: image });
6768
}
6869

6970
pub struct GameOfLifeComputePlugin;
@@ -93,8 +94,11 @@ impl Plugin for GameOfLifeComputePlugin {
9394
}
9495
}
9596

96-
#[derive(Resource, Clone, Deref, ExtractResource)]
97-
struct GameOfLifeImage(Handle<Image>);
97+
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
98+
struct GameOfLifeImage {
99+
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
100+
texture: Handle<Image>,
101+
}
98102

99103
#[derive(Resource)]
100104
struct GameOfLifeImageBindGroup(BindGroup);
@@ -106,7 +110,7 @@ fn prepare_bind_group(
106110
game_of_life_image: Res<GameOfLifeImage>,
107111
render_device: Res<RenderDevice>,
108112
) {
109-
let view = gpu_images.get(&game_of_life_image.0).unwrap();
113+
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
110114
let bind_group = render_device.create_bind_group(
111115
None,
112116
&pipeline.texture_bind_group_layout,
@@ -124,13 +128,8 @@ pub struct GameOfLifePipeline {
124128

125129
impl FromWorld for GameOfLifePipeline {
126130
fn from_world(world: &mut World) -> Self {
127-
let texture_bind_group_layout = world.resource::<RenderDevice>().create_bind_group_layout(
128-
None,
129-
&BindGroupLayoutEntries::single(
130-
ShaderStages::COMPUTE,
131-
texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite),
132-
),
133-
);
131+
let render_device = world.resource::<RenderDevice>();
132+
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
134133
let shader = world
135134
.resource::<AssetServer>()
136135
.load("shaders/game_of_life.wgsl");
@@ -217,7 +216,7 @@ impl render_graph::Node for GameOfLifeNode {
217216
.command_encoder()
218217
.begin_compute_pass(&ComputePassDescriptor::default());
219218

220-
pass.set_bind_group(0, texture_bind_group, &[]);
219+
pass.set_bind_group(0, &texture_bind_group, &[]);
221220

222221
// select the pipeline based on the current state
223222
match self.state {

0 commit comments

Comments
 (0)