@@ -11,6 +11,7 @@ use syn::{
11
11
12
12
const UNIFORM_ATTRIBUTE_NAME : Symbol = Symbol ( "uniform" ) ;
13
13
const TEXTURE_ATTRIBUTE_NAME : Symbol = Symbol ( "texture" ) ;
14
+ const STORAGE_TEXTURE_ATTRIBUTE_NAME : Symbol = Symbol ( "storage_texture" ) ;
14
15
const SAMPLER_ATTRIBUTE_NAME : Symbol = Symbol ( "sampler" ) ;
15
16
const STORAGE_ATTRIBUTE_NAME : Symbol = Symbol ( "storage" ) ;
16
17
const BIND_GROUP_DATA_ATTRIBUTE_NAME : Symbol = Symbol ( "bind_group_data" ) ;
@@ -19,6 +20,7 @@ const BIND_GROUP_DATA_ATTRIBUTE_NAME: Symbol = Symbol("bind_group_data");
19
20
enum BindingType {
20
21
Uniform ,
21
22
Texture ,
23
+ StorageTexture ,
22
24
Sampler ,
23
25
Storage ,
24
26
}
@@ -139,6 +141,8 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
139
141
BindingType :: Uniform
140
142
} else if attr_ident == TEXTURE_ATTRIBUTE_NAME {
141
143
BindingType :: Texture
144
+ } else if attr_ident == STORAGE_TEXTURE_ATTRIBUTE_NAME {
145
+ BindingType :: StorageTexture
142
146
} else if attr_ident == SAMPLER_ATTRIBUTE_NAME {
143
147
BindingType :: Sampler
144
148
} else if attr_ident == STORAGE_ATTRIBUTE_NAME {
@@ -262,6 +266,43 @@ pub fn derive_as_bind_group(ast: syn::DeriveInput) -> Result<TokenStream> {
262
266
}
263
267
} ) ;
264
268
}
269
+ BindingType :: StorageTexture => {
270
+ let StorageTextureAttrs {
271
+ dimension,
272
+ image_format,
273
+ access,
274
+ visibility,
275
+ } = get_storage_texture_binding_attr ( nested_meta_items) ?;
276
+
277
+ let visibility =
278
+ visibility. hygienic_quote ( & quote ! { #render_path:: render_resource } ) ;
279
+
280
+ let fallback_image = get_fallback_image ( & render_path, dimension) ;
281
+
282
+ binding_impls. push ( quote ! {
283
+ #render_path:: render_resource:: OwnedBindingResource :: TextureView ( {
284
+ let handle: Option <& #asset_path:: Handle <#render_path:: texture:: Image >> = ( & self . #field_name) . into( ) ;
285
+ if let Some ( handle) = handle {
286
+ images. get( handle) . ok_or_else( || #render_path:: render_resource:: AsBindGroupError :: RetryNextUpdate ) ?. texture_view. clone( )
287
+ } else {
288
+ #fallback_image. texture_view. clone( )
289
+ }
290
+ } )
291
+ } ) ;
292
+
293
+ binding_layouts. push ( quote ! {
294
+ #render_path:: render_resource:: BindGroupLayoutEntry {
295
+ binding: #binding_index,
296
+ visibility: #visibility,
297
+ ty: #render_path:: render_resource:: BindingType :: StorageTexture {
298
+ access: #render_path:: render_resource:: StorageTextureAccess :: #access,
299
+ format: #render_path:: render_resource:: TextureFormat :: #image_format,
300
+ view_dimension: #render_path:: render_resource:: #dimension,
301
+ } ,
302
+ count: None ,
303
+ }
304
+ } ) ;
305
+ }
265
306
BindingType :: Texture => {
266
307
let TextureAttrs {
267
308
dimension,
@@ -593,6 +634,10 @@ impl ShaderStageVisibility {
593
634
fn vertex_fragment ( ) -> Self {
594
635
Self :: Flags ( VisibilityFlags :: vertex_fragment ( ) )
595
636
}
637
+
638
+ fn compute ( ) -> Self {
639
+ Self :: Flags ( VisibilityFlags :: compute ( ) )
640
+ }
596
641
}
597
642
598
643
impl VisibilityFlags {
@@ -603,6 +648,13 @@ impl VisibilityFlags {
603
648
..Default :: default ( )
604
649
}
605
650
}
651
+
652
+ fn compute ( ) -> Self {
653
+ Self {
654
+ compute : true ,
655
+ ..Default :: default ( )
656
+ }
657
+ }
606
658
}
607
659
608
660
impl ShaderStageVisibility {
@@ -749,7 +801,72 @@ impl Default for TextureAttrs {
749
801
}
750
802
}
751
803
804
+ struct StorageTextureAttrs {
805
+ dimension : BindingTextureDimension ,
806
+ // Parsing of the image_format parameter is deferred to the type checker,
807
+ // which will error if the format is not member of the TextureFormat enum.
808
+ image_format : proc_macro2:: TokenStream ,
809
+ // Parsing of the access parameter is deferred to the type checker,
810
+ // which will error if the format is not member of the TextureFormat enum.
811
+ access : proc_macro2:: TokenStream ,
812
+ visibility : ShaderStageVisibility ,
813
+ }
814
+
815
+ impl Default for StorageTextureAttrs {
816
+ fn default ( ) -> Self {
817
+ Self {
818
+ dimension : Default :: default ( ) ,
819
+ image_format : quote ! { Rgba8Unorm } ,
820
+ access : quote ! { ReadWrite } ,
821
+ visibility : ShaderStageVisibility :: compute ( ) ,
822
+ }
823
+ }
824
+ }
825
+
826
+ fn get_storage_texture_binding_attr ( metas : Vec < Meta > ) -> Result < StorageTextureAttrs > {
827
+ let mut storage_texture_attrs = StorageTextureAttrs :: default ( ) ;
828
+
829
+ for meta in metas {
830
+ use syn:: Meta :: { List , NameValue } ;
831
+ match meta {
832
+ // Parse #[storage_texture(0, dimension = "...")].
833
+ NameValue ( m) if m. path == DIMENSION => {
834
+ let value = get_lit_str ( DIMENSION , & m. value ) ?;
835
+ storage_texture_attrs. dimension = get_texture_dimension_value ( value) ?;
836
+ }
837
+ // Parse #[storage_texture(0, format = ...))].
838
+ NameValue ( m) if m. path == IMAGE_FORMAT => {
839
+ storage_texture_attrs. image_format = m. value . into_token_stream ( ) ;
840
+ }
841
+ // Parse #[storage_texture(0, access = ...))].
842
+ NameValue ( m) if m. path == ACCESS => {
843
+ storage_texture_attrs. access = m. value . into_token_stream ( ) ;
844
+ }
845
+ // Parse #[storage_texture(0, visibility(...))].
846
+ List ( m) if m. path == VISIBILITY => {
847
+ storage_texture_attrs. visibility = get_visibility_flag_value ( & m) ?;
848
+ }
849
+ NameValue ( m) => {
850
+ return Err ( Error :: new_spanned (
851
+ m. path ,
852
+ "Not a valid name. Available attributes: `dimension`, `image_format`, `access`." ,
853
+ ) ) ;
854
+ }
855
+ _ => {
856
+ return Err ( Error :: new_spanned (
857
+ meta,
858
+ "Not a name value pair: `foo = \" ...\" `" ,
859
+ ) ) ;
860
+ }
861
+ }
862
+ }
863
+
864
+ Ok ( storage_texture_attrs)
865
+ }
866
+
752
867
const DIMENSION : Symbol = Symbol ( "dimension" ) ;
868
+ const IMAGE_FORMAT : Symbol = Symbol ( "image_format" ) ;
869
+ const ACCESS : Symbol = Symbol ( "access" ) ;
753
870
const SAMPLE_TYPE : Symbol = Symbol ( "sample_type" ) ;
754
871
const FILTERABLE : Symbol = Symbol ( "filterable" ) ;
755
872
const MULTISAMPLED : Symbol = Symbol ( "multisampled" ) ;
0 commit comments