Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Live reloading of shaders #937

Merged
merged 9 commits into from
Dec 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,10 @@ path = "examples/reflection/trait_reflection.rs"
name = "scene"
path = "examples/scene/scene.rs"

[[example]]
name = "hot_shader_reloading"
path = "examples/shader/hot_shader_reloading.rs"

[[example]]
name = "mesh_custom_attribute"
path = "examples/shader/mesh_custom_attribute.rs"
Expand Down
11 changes: 11 additions & 0 deletions assets/shaders/hot.frag
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
#version 450

layout(location = 0) out vec4 o_Target;

layout(set = 2, binding = 0) uniform MyMaterial_color {
vec4 color;
};

void main() {
o_Target = color * 0.5;
}
15 changes: 15 additions & 0 deletions assets/shaders/hot.vert
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#version 450

layout(location = 0) in vec3 Vertex_Position;

layout(set = 0, binding = 0) uniform Camera {
mat4 ViewProj;
};

layout(set = 1, binding = 0) uniform Transform {
mat4 Model;
};

void main() {
gl_Position = ViewProj * Model * vec4(Vertex_Position, 1.0);
}
4 changes: 4 additions & 0 deletions crates/bevy_render/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ use render_graph::{
RenderGraph,
};
use renderer::{AssetRenderResourceBindings, RenderResourceBindings};
use shader::ShaderLoader;
#[cfg(feature = "hdr")]
use texture::HdrTextureLoader;
#[cfg(feature = "png")]
Expand Down Expand Up @@ -87,6 +88,8 @@ impl Plugin for RenderPlugin {
app.init_asset_loader::<HdrTextureLoader>();
}

app.init_asset_loader::<ShaderLoader>();

if app.resources().get::<ClearColor>().is_none() {
app.resources_mut().insert(ClearColor::default());
}
Expand Down Expand Up @@ -134,6 +137,7 @@ impl Plugin for RenderPlugin {
camera::visible_entities_system,
)
// TODO: turn these "resource systems" into graph nodes and remove the RENDER_RESOURCE stage
.add_system_to_stage(stage::RENDER_RESOURCE, shader::shader_update_system)
.add_system_to_stage(stage::RENDER_RESOURCE, mesh::mesh_resource_provider_system)
.add_system_to_stage(stage::RENDER_RESOURCE, Texture::texture_resource_system)
.add_system_to_stage(
Expand Down
109 changes: 91 additions & 18 deletions crates/bevy_render/src/pipeline/pipeline_compiler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::{state_descriptors::PrimitiveTopology, IndexFormat, PipelineDescripto
use crate::{
pipeline::{BindType, InputStepMode, VertexBufferDescriptor},
renderer::RenderResourceContext,
shader::{Shader, ShaderSource},
shader::{Shader, ShaderError, ShaderSource},
};
use bevy_asset::{Assets, Handle};
use bevy_reflect::Reflect;
Expand Down Expand Up @@ -60,6 +60,7 @@ struct SpecializedPipeline {
#[derive(Debug, Default)]
pub struct PipelineCompiler {
specialized_shaders: HashMap<Handle<Shader>, Vec<SpecializedShader>>,
specialized_shader_pipelines: HashMap<Handle<Shader>, Vec<Handle<PipelineDescriptor>>>,
specialized_pipelines: HashMap<Handle<PipelineDescriptor>, Vec<SpecializedPipeline>>,
}

Expand All @@ -70,7 +71,7 @@ impl PipelineCompiler {
shaders: &mut Assets<Shader>,
shader_handle: &Handle<Shader>,
shader_specialization: &ShaderSpecialization,
) -> Handle<Shader> {
) -> Result<Handle<Shader>, ShaderError> {
let specialized_shaders = self
.specialized_shaders
.entry(shader_handle.clone_weak())
Expand All @@ -80,7 +81,7 @@ impl PipelineCompiler {

// don't produce new shader if the input source is already spirv
if let ShaderSource::Spirv(_) = shader.source {
return shader_handle.clone_weak();
return Ok(shader_handle.clone_weak());
}

if let Some(specialized_shader) =
Expand All @@ -91,7 +92,7 @@ impl PipelineCompiler {
})
{
// if shader has already been compiled with current configuration, use existing shader
specialized_shader.shader.clone_weak()
Ok(specialized_shader.shader.clone_weak())
} else {
// if no shader exists with the current configuration, create new shader and compile
let shader_def_vec = shader_specialization
Expand All @@ -100,14 +101,14 @@ impl PipelineCompiler {
.cloned()
.collect::<Vec<String>>();
let compiled_shader =
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec));
render_resource_context.get_specialized_shader(shader, Some(&shader_def_vec))?;
let specialized_handle = shaders.add(compiled_shader);
let weak_specialized_handle = specialized_handle.clone_weak();
specialized_shaders.push(SpecializedShader {
shader: specialized_handle,
specialization: shader_specialization.clone(),
});
weak_specialized_handle
Ok(weak_specialized_handle)
}
}

Expand Down Expand Up @@ -138,23 +139,31 @@ impl PipelineCompiler {
) -> Handle<PipelineDescriptor> {
let source_descriptor = pipelines.get(source_pipeline).unwrap();
let mut specialized_descriptor = source_descriptor.clone();
specialized_descriptor.shader_stages.vertex = self.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
);
let specialized_vertex_shader = self
.compile_shader(
render_resource_context,
shaders,
&specialized_descriptor.shader_stages.vertex,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_descriptor.shader_stages.vertex = specialized_vertex_shader.clone_weak();
let mut specialized_fragment_shader = None;
specialized_descriptor.shader_stages.fragment = specialized_descriptor
.shader_stages
.fragment
.as_ref()
.map(|fragment| {
self.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
let shader = self
.compile_shader(
render_resource_context,
shaders,
fragment,
&pipeline_specialization.shader_specialization,
)
.unwrap();
specialized_fragment_shader = Some(shader.clone_weak());
shader
});

let mut layout = render_resource_context.reflect_pipeline_layout(
Expand Down Expand Up @@ -244,6 +253,18 @@ impl PipelineCompiler {
&shaders,
);

// track specialized shader pipelines
self.specialized_shader_pipelines
.entry(specialized_vertex_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
if let Some(specialized_fragment_shader) = specialized_fragment_shader {
self.specialized_shader_pipelines
.entry(specialized_fragment_shader)
.or_insert_with(Default::default)
.push(source_pipeline.clone_weak());
}

let specialized_pipelines = self
.specialized_pipelines
.entry(source_pipeline.clone_weak())
Expand Down Expand Up @@ -282,4 +303,56 @@ impl PipelineCompiler {
})
.flatten()
}

/// Update specialized shaders and remove any related specialized
/// pipelines and assets.
pub fn update_shader(
&mut self,
shader: &Handle<Shader>,
pipelines: &mut Assets<PipelineDescriptor>,
shaders: &mut Assets<Shader>,
render_resource_context: &dyn RenderResourceContext,
) -> Result<(), ShaderError> {
if let Some(specialized_shaders) = self.specialized_shaders.get_mut(shader) {
for specialized_shader in specialized_shaders {
// Recompile specialized shader. If it fails, we bail immediately.
let shader_def_vec = specialized_shader
.specialization
.shader_defs
.iter()
.cloned()
.collect::<Vec<String>>();
let new_handle =
shaders.add(render_resource_context.get_specialized_shader(
shaders.get(shader).unwrap(),
Some(&shader_def_vec),
)?);

// Replace handle and remove old from assets.
let old_handle = std::mem::replace(&mut specialized_shader.shader, new_handle);
shaders.remove(&old_handle);

// Find source pipelines that use the old specialized
// shader, and remove from tracking.
if let Some(source_pipelines) =
self.specialized_shader_pipelines.remove(&old_handle)
{
// Remove all specialized pipelines from tracking
// and asset storage. They will be rebuilt on next
// draw.
for source_pipeline in source_pipelines {
if let Some(specialized_pipelines) =
self.specialized_pipelines.remove(&source_pipeline)
{
for p in specialized_pipelines {
pipelines.remove(p.pipeline);
}
}
}
}
}
}

Ok(())
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use super::RenderResourceContext;
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::Shader,
shader::{Shader, ShaderError},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Assets, Handle, HandleUntyped};
Expand Down Expand Up @@ -149,8 +149,12 @@ impl RenderResourceContext for HeadlessRenderResourceContext {
size
}

fn get_specialized_shader(&self, shader: &Shader, _macros: Option<&[String]>) -> Shader {
shader.clone()
fn get_specialized_shader(
&self,
shader: &Shader,
_macros: Option<&[String]>,
) -> Result<Shader, ShaderError> {
Ok(shader.clone())
}

fn remove_stale_bind_groups(&self) {}
Expand Down
8 changes: 6 additions & 2 deletions crates/bevy_render/src/renderer/render_resource_context.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::{
pipeline::{BindGroupDescriptorId, PipelineDescriptor, PipelineLayout},
renderer::{BindGroup, BufferId, BufferInfo, RenderResourceId, SamplerId, TextureId},
shader::{Shader, ShaderLayout, ShaderStages},
shader::{Shader, ShaderError, ShaderLayout, ShaderStages},
texture::{SamplerDescriptor, TextureDescriptor},
};
use bevy_asset::{Asset, Assets, Handle, HandleUntyped};
Expand Down Expand Up @@ -29,7 +29,11 @@ pub trait RenderResourceContext: Downcast + Send + Sync + 'static {
fn create_buffer_with_data(&self, buffer_info: BufferInfo, data: &[u8]) -> BufferId;
fn create_shader_module(&self, shader_handle: &Handle<Shader>, shaders: &Assets<Shader>);
fn create_shader_module_from_source(&self, shader_handle: &Handle<Shader>, shader: &Shader);
fn get_specialized_shader(&self, shader: &Shader, macros: Option<&[String]>) -> Shader;
fn get_specialized_shader(
&self,
shader: &Shader,
macros: Option<&[String]>,
) -> Result<Shader, ShaderError>;
fn remove_buffer(&self, buffer: BufferId);
fn remove_texture(&self, texture: TextureId);
fn remove_sampler(&self, sampler: SamplerId);
Expand Down
Loading