@@ -2,7 +2,7 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
22use crate :: {
33 pipeline:: { BindType , InputStepMode , VertexBufferDescriptor } ,
44 renderer:: RenderResourceContext ,
5- shader:: { Shader , ShaderSource } ,
5+ shader:: { Shader , ShaderError , ShaderSource } ,
66} ;
77use bevy_asset:: { Assets , Handle } ;
88use bevy_reflect:: Reflect ;
@@ -60,6 +60,7 @@ struct SpecializedPipeline {
6060#[ derive( Debug , Default ) ]
6161pub struct PipelineCompiler {
6262 specialized_shaders : HashMap < Handle < Shader > , Vec < SpecializedShader > > ,
63+ specialized_shader_pipelines : HashMap < Handle < Shader > , Vec < Handle < PipelineDescriptor > > > ,
6364 specialized_pipelines : HashMap < Handle < PipelineDescriptor > , Vec < SpecializedPipeline > > ,
6465}
6566
@@ -70,7 +71,7 @@ impl PipelineCompiler {
7071 shaders : & mut Assets < Shader > ,
7172 shader_handle : & Handle < Shader > ,
7273 shader_specialization : & ShaderSpecialization ,
73- ) -> Handle < Shader > {
74+ ) -> Result < Handle < Shader > , ShaderError > {
7475 let specialized_shaders = self
7576 . specialized_shaders
7677 . entry ( shader_handle. clone_weak ( ) )
@@ -80,7 +81,7 @@ impl PipelineCompiler {
8081
8182 // don't produce new shader if the input source is already spirv
8283 if let ShaderSource :: Spirv ( _) = shader. source {
83- return shader_handle. clone_weak ( ) ;
84+ return Ok ( shader_handle. clone_weak ( ) ) ;
8485 }
8586
8687 if let Some ( specialized_shader) =
@@ -91,7 +92,7 @@ impl PipelineCompiler {
9192 } )
9293 {
9394 // if shader has already been compiled with current configuration, use existing shader
94- specialized_shader. shader . clone_weak ( )
95+ Ok ( specialized_shader. shader . clone_weak ( ) )
9596 } else {
9697 // if no shader exists with the current configuration, create new shader and compile
9798 let shader_def_vec = shader_specialization
@@ -100,14 +101,14 @@ impl PipelineCompiler {
100101 . cloned ( )
101102 . collect :: < Vec < String > > ( ) ;
102103 let compiled_shader =
103- render_resource_context. get_specialized_shader ( shader, Some ( & shader_def_vec) ) ;
104+ render_resource_context. get_specialized_shader ( shader, Some ( & shader_def_vec) ) ? ;
104105 let specialized_handle = shaders. add ( compiled_shader) ;
105106 let weak_specialized_handle = specialized_handle. clone_weak ( ) ;
106107 specialized_shaders. push ( SpecializedShader {
107108 shader : specialized_handle,
108109 specialization : shader_specialization. clone ( ) ,
109110 } ) ;
110- weak_specialized_handle
111+ Ok ( weak_specialized_handle)
111112 }
112113 }
113114
@@ -138,23 +139,31 @@ impl PipelineCompiler {
138139 ) -> Handle < PipelineDescriptor > {
139140 let source_descriptor = pipelines. get ( source_pipeline) . unwrap ( ) ;
140141 let mut specialized_descriptor = source_descriptor. clone ( ) ;
141- specialized_descriptor. shader_stages . vertex = self . compile_shader (
142- render_resource_context,
143- shaders,
144- & specialized_descriptor. shader_stages . vertex ,
145- & pipeline_specialization. shader_specialization ,
146- ) ;
142+ let specialized_vertex_shader = self
143+ . compile_shader (
144+ render_resource_context,
145+ shaders,
146+ & specialized_descriptor. shader_stages . vertex ,
147+ & pipeline_specialization. shader_specialization ,
148+ )
149+ . unwrap ( ) ;
150+ specialized_descriptor. shader_stages . vertex = specialized_vertex_shader. clone_weak ( ) ;
151+ let mut specialized_fragment_shader = None ;
147152 specialized_descriptor. shader_stages . fragment = specialized_descriptor
148153 . shader_stages
149154 . fragment
150155 . as_ref ( )
151156 . map ( |fragment| {
152- self . compile_shader (
153- render_resource_context,
154- shaders,
155- fragment,
156- & pipeline_specialization. shader_specialization ,
157- )
157+ let shader = self
158+ . compile_shader (
159+ render_resource_context,
160+ shaders,
161+ fragment,
162+ & pipeline_specialization. shader_specialization ,
163+ )
164+ . unwrap ( ) ;
165+ specialized_fragment_shader = Some ( shader. clone_weak ( ) ) ;
166+ shader
158167 } ) ;
159168
160169 let mut layout = render_resource_context. reflect_pipeline_layout (
@@ -244,6 +253,18 @@ impl PipelineCompiler {
244253 & shaders,
245254 ) ;
246255
256+ // track specialized shader pipelines
257+ self . specialized_shader_pipelines
258+ . entry ( specialized_vertex_shader)
259+ . or_insert_with ( Default :: default)
260+ . push ( source_pipeline. clone_weak ( ) ) ;
261+ if let Some ( specialized_fragment_shader) = specialized_fragment_shader {
262+ self . specialized_shader_pipelines
263+ . entry ( specialized_fragment_shader)
264+ . or_insert_with ( Default :: default)
265+ . push ( source_pipeline. clone_weak ( ) ) ;
266+ }
267+
247268 let specialized_pipelines = self
248269 . specialized_pipelines
249270 . entry ( source_pipeline. clone_weak ( ) )
@@ -282,4 +303,56 @@ impl PipelineCompiler {
282303 } )
283304 . flatten ( )
284305 }
306+
307+ /// Update specialized shaders and remove any related specialized
308+ /// pipelines and assets.
309+ pub fn update_shader (
310+ & mut self ,
311+ shader : & Handle < Shader > ,
312+ pipelines : & mut Assets < PipelineDescriptor > ,
313+ shaders : & mut Assets < Shader > ,
314+ render_resource_context : & dyn RenderResourceContext ,
315+ ) -> Result < ( ) , ShaderError > {
316+ if let Some ( specialized_shaders) = self . specialized_shaders . get_mut ( shader) {
317+ for specialized_shader in specialized_shaders {
318+ // Recompile specialized shader. If it fails, we bail immediately.
319+ let shader_def_vec = specialized_shader
320+ . specialization
321+ . shader_defs
322+ . iter ( )
323+ . cloned ( )
324+ . collect :: < Vec < String > > ( ) ;
325+ let new_handle =
326+ shaders. add ( render_resource_context. get_specialized_shader (
327+ shaders. get ( shader) . unwrap ( ) ,
328+ Some ( & shader_def_vec) ,
329+ ) ?) ;
330+
331+ // Replace handle and remove old from assets.
332+ let old_handle = std:: mem:: replace ( & mut specialized_shader. shader , new_handle) ;
333+ shaders. remove ( & old_handle) ;
334+
335+ // Find source pipelines that use the old specialized
336+ // shader, and remove from tracking.
337+ if let Some ( source_pipelines) =
338+ self . specialized_shader_pipelines . remove ( & old_handle)
339+ {
340+ // Remove all specialized pipelines from tracking
341+ // and asset storage. They will be rebuilt on next
342+ // draw.
343+ for source_pipeline in source_pipelines {
344+ if let Some ( specialized_pipelines) =
345+ self . specialized_pipelines . remove ( & source_pipeline)
346+ {
347+ for p in specialized_pipelines {
348+ pipelines. remove ( p. pipeline ) ;
349+ }
350+ }
351+ }
352+ }
353+ }
354+ }
355+
356+ Ok ( ( ) )
357+ }
285358}
0 commit comments