Skip to content

add storage_texture option to as_bind_group macro #9943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions crates/bevy_render/macros/src/as_bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use syn::{

const UNIFORM_ATTRIBUTE_NAME: Symbol = Symbol("uniform");
const TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("texture");
const STORAGE_TEXTURE_ATTRIBUTE_NAME: Symbol = Symbol("storage_texture");
const SAMPLER_ATTRIBUTE_NAME: Symbol = Symbol("sampler");
const STORAGE_ATTRIBUTE_NAME: Symbol = Symbol("storage");
const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
Expand All @@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
enum BindingType {
Uniform,
Texture,
StorageTexture,
Sampler,
Storage,
}
Expand Down Expand Up @@ -133,6 +135,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
BindingType::Uniform
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
BindingType::Texture
} else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
BindingType::StorageTexture
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
BindingType::Sampler
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
Expand Down Expand Up @@ -255,6 +259,45 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
}
});
}
BindingType::StorageTexture => {
let StorageTextureAttrs {
dimension,
image_format,
access,
visibility,
} = get_storage_texture_binding_attr(nested_meta_items)?;

let visibility =
visibility.hygienic_quote(&quote! { #render_path::render_resource });

let fallback_image = get_fallback_image(&render_path, dimension);

binding_impls.push(quote! {
( #binding_index,
#render_path::render_resource::OwnedBindingResource::TextureView({
let handle: Option<&#asset_path::Handle<#render_path::texture::Image>> = (&self.#field_name).into();
if let Some(handle) = handle {
images.get(handle).ok_or_else(|| #render_path::render_resource::AsBindGroupError::RetryNextUpdate)?.texture_view.clone()
} else {
#fallback_image.texture_view.clone()
}
})
)
});

binding_layouts.push(quote! {
#render_path::render_resource::BindGroupLayoutEntry {
binding: #binding_index,
visibility: #visibility,
ty: #render_path::render_resource::BindingType::StorageTexture {
access: #render_path::render_resource::StorageTextureAccess::#access,
format: #render_path::render_resource::TextureFormat::#image_format,
view_dimension: #render_path::render_resource::#dimension,
},
count: None,
}
});
}
BindingType::Texture => {
let TextureAttrs {
dimension,
Expand Down Expand Up @@ -585,6 +628,10 @@ impl ShaderStageVisibility {
fn vertex_fragment() -> Self {
Self::Flags(VisibilityFlags::vertex_fragment())
}

fn compute() -> Self {
Self::Flags(VisibilityFlags::compute())
}
}

impl VisibilityFlags {
Expand All @@ -595,6 +642,13 @@ impl VisibilityFlags {
..Default::default()
}
}

fn compute() -> Self {
Self {
compute: true,
..Default::default()
}
}
}

impl ShaderStageVisibility {
Expand Down Expand Up @@ -741,7 +795,72 @@ impl Default for TextureAttrs {
}
}

struct StorageTextureAttrs {
dimension: BindingTextureDimension,
// Parsing of the image_format parameter is deferred to the type checker,
// which will error if the format is not member of the TextureFormat enum.
image_format: proc_macro2::TokenStream,
// Parsing of the access parameter is deferred to the type checker,
// which will error if the access is not member of the StorageTextureAccess enum.
access: proc_macro2::TokenStream,
visibility: ShaderStageVisibility,
}

impl Default for StorageTextureAttrs {
fn default() -> Self {
Self {
dimension: Default::default(),
image_format: quote! { Rgba8Unorm },
access: quote! { ReadWrite },
visibility: ShaderStageVisibility::compute(),
}
}
}

fn get_storage_texture_binding_attr(metas: Vec<Meta>) -> Result<StorageTextureAttrs> {
let mut storage_texture_attrs = StorageTextureAttrs::default();

for meta in metas {
use syn::Meta::{List, NameValue};
match meta {
// Parse #[storage_texture(0, dimension = "...")].
NameValue(m) if m.path == DIMENSION => {
let value = get_lit_str(DIMENSION, &m.value)?;
storage_texture_attrs.dimension = get_texture_dimension_value(value)?;
}
// Parse #[storage_texture(0, format = ...))].
NameValue(m) if m.path == IMAGE_FORMAT => {
storage_texture_attrs.image_format = m.value.into_token_stream();
}
// Parse #[storage_texture(0, access = ...))].
NameValue(m) if m.path == ACCESS => {
storage_texture_attrs.access = m.value.into_token_stream();
}
// Parse #[storage_texture(0, visibility(...))].
List(m) if m.path == VISIBILITY => {
storage_texture_attrs.visibility = get_visibility_flag_value(&m)?;
}
NameValue(m) => {
return Err(Error::new_spanned(
m.path,
"Not a valid name. Available attributes: `dimension`, `image_format`, `access`.",
));
}
_ => {
return Err(Error::new_spanned(
meta,
"Not a name value pair: `foo = \"...\"`",
));
}
}
}

Ok(storage_texture_attrs)
}

const DIMENSION: Symbol = Symbol("dimension");
const IMAGE_FORMAT: Symbol = Symbol("image_format");
const ACCESS: Symbol = Symbol("access");
const SAMPLE_TYPE: Symbol = Symbol("sample_type");
const FILTERABLE: Symbol = Symbol("filterable");
const MULTISAMPLED: Symbol = Symbol("multisampled");
Expand Down
2 changes: 1 addition & 1 deletion crates/bevy_render/macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub fn derive_extract_component(input: TokenStream) -> TokenStream {

#[proc_macro_derive(
AsBindGroup,
attributes(uniform, texture, sampler, bind_group_data, storage)
attributes(uniform, storage_texture, texture, sampler, bind_group_data, storage)
)]
pub fn derive_as_bind_group(input: TokenStream) -> TokenStream {
let input = parse_macro_input!(input as DeriveInput);
Expand Down
16 changes: 16 additions & 0 deletions crates/bevy_render/src/render_resource/bind_group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ impl Deref for BindGroup {
/// values: Vec<f32>,
/// #[storage(4, read_only, buffer)]
/// buffer: Buffer,
/// #[storage_texture(5)]
/// storage_texture: Handle<Image>,
/// }
/// ```
///
Expand All @@ -97,6 +99,7 @@ impl Deref for BindGroup {
/// @group(2) @binding(1) var color_texture: texture_2d<f32>;
/// @group(2) @binding(2) var color_sampler: sampler;
/// @group(2) @binding(3) var<storage> values: array<f32>;
/// @group(2) @binding(5) var storage_texture: texture_storage_2d<rgba8unorm, read_write>;
/// ```
/// Note that the "group" index is determined by the usage context. It is not defined in [`AsBindGroup`]. For example, in Bevy material bind groups
/// are generally bound to group 2.
Expand All @@ -123,6 +126,19 @@ impl Deref for BindGroup {
/// | `multisampled` = ... | `true`, `false` | `false` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `vertex`, `fragment` |
///
/// * `storage_texture(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Texture`](crate::render_resource::Texture)
/// GPU resource, which will be bound as a storage texture in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
/// most fields should be a [`Handle<Image>`](bevy_asset::Handle) or [`Option<Handle<Image>>`]. If the value of an [`Option<Handle<Image>>`] is
/// [`None`], the [`FallbackImage`] resource will be used instead.
///
/// | Arguments | Values | Default |
/// |------------------------|--------------------------------------------------------------------------------------------|---------------|
/// | `dimension` = "..." | `"1d"`, `"2d"`, `"2d_array"`, `"3d"`, `"cube"`, `"cube_array"` | `"2d"` |
/// | `image_format` = ... | any member of [`TextureFormat`](crate::render_resource::TextureFormat) | `Rgba8Unorm` |
/// | `access` = ... | any member of [`StorageTextureAccess`](crate::render_resource::StorageTextureAccess) | `ReadWrite` |
/// | `visibility(...)` | `all`, `none`, or a list-combination of `vertex`, `fragment`, `compute` | `compute` |
///
/// * `sampler(BINDING_INDEX, arguments)`
/// * This field's [`Handle<Image>`](bevy_asset::Handle) will be used to look up the matching [`Sampler`] GPU
/// resource, which will be bound as a sampler in shaders. The field will be assumed to implement [`Into<Option<Handle<Image>>>`]. In practice,
Expand Down
22 changes: 10 additions & 12 deletions examples/shader/compute_shader_game_of_life.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use bevy::{
render_asset::RenderAssetPersistencePolicy,
render_asset::RenderAssets,
render_graph::{self, RenderGraph},
render_resource::{binding_types::texture_storage_2d, *},
render_resource::*,
renderer::{RenderContext, RenderDevice},
Render, RenderApp, RenderSet,
},
Expand Down Expand Up @@ -65,7 +65,7 @@ fn setup(mut commands: Commands, mut images: ResMut<Assets<Image>>) {
});
commands.spawn(Camera2dBundle::default());

commands.insert_resource(GameOfLifeImage(image));
commands.insert_resource(GameOfLifeImage { texture: image });
}

pub struct GameOfLifeComputePlugin;
Expand Down Expand Up @@ -95,8 +95,11 @@ impl Plugin for GameOfLifeComputePlugin {
}
}

#[derive(Resource, Clone, Deref, ExtractResource)]
struct GameOfLifeImage(Handle<Image>);
#[derive(Resource, Clone, Deref, ExtractResource, AsBindGroup)]
struct GameOfLifeImage {
#[storage_texture(0, image_format = Rgba8Unorm, access = ReadWrite)]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good idea to use it here!

texture: Handle<Image>,
}

#[derive(Resource)]
struct GameOfLifeImageBindGroup(BindGroup);
Expand All @@ -108,7 +111,7 @@ fn prepare_bind_group(
game_of_life_image: Res<GameOfLifeImage>,
render_device: Res<RenderDevice>,
) {
let view = gpu_images.get(&game_of_life_image.0).unwrap();
let view = gpu_images.get(&game_of_life_image.texture).unwrap();
let bind_group = render_device.create_bind_group(
None,
&pipeline.texture_bind_group_layout,
Expand All @@ -126,13 +129,8 @@ pub struct GameOfLifePipeline {

impl FromWorld for GameOfLifePipeline {
fn from_world(world: &mut World) -> Self {
let texture_bind_group_layout = world.resource::<RenderDevice>().create_bind_group_layout(
None,
&BindGroupLayoutEntries::single(
ShaderStages::COMPUTE,
texture_storage_2d(TextureFormat::Rgba8Unorm, StorageTextureAccess::ReadWrite),
),
);
let render_device = world.resource::<RenderDevice>();
let texture_bind_group_layout = GameOfLifeImage::bind_group_layout(render_device);
let shader = world
.resource::<AssetServer>()
.load("shaders/game_of_life.wgsl");
Expand Down