Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
lxvm committed Apr 5, 2023
1 parent f632411 commit 88eb405
Showing 1 changed file with 138 additions and 61 deletions.
199 changes: 138 additions & 61 deletions ext/BrillouinPlotlyJSExt/dispersion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,77 @@ const DEFAULT_PLOTLY_LAYOUT_DISPERSION = Layout(
plot_bgcolor=TRANSPARENT_COL[], paper_bgcolor=TRANSPARENT_COL[],
)


function plot(kpi::KPathInterpolant, traces::Vector{<:AbstractTrace},
layout::Layout = Layout();
ylims = nothing, ylabel = nothing, title = nothing,
config::PlotConfig = PlotConfig(responsive=true, displaylogo=false))

# merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without
# corrupting user input)
layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout)

# set default y-limits in layout, if not already set
yaxis = get!(layout.fields, :yaxis, attr())
if isnothing(ylims)
if !haskey(yaxis, :range)
yaxis[:range] = ylims
else
ylims = yaxis[:range] # grab what was already in `layout`
end
else
# overwrite if ylims was provided, regardless of what it is in `layout`
yaxis[:range] = ylims
end
yaxis[:title] = ylabel

# add title, if requested
if !isnothing(title)
if title isa String
layout[:title] = attr(text=title)
else
layout[:title] = title
end
end

# prepare to plot band diagram
Npaths = length(kpi.kpaths)
local_xs = cumdists.(cartesianize(kpi).kpaths)
local_xs_lengths = last.(local_xs)
total_xs_lengths = sum(local_xs_lengths)
spacing = total_xs_lengths / 30
rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1))
rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1))

# plot k-lines/labels
xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels]
xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels]
domain_start = 0.0 # subplot domain "start" point
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
# define xticks
for (lab_idx, (x_idx, lab)) in enumerate(labels)
xticks[path_idx][lab_idx] = local_x[x_idx]
xlabs[path_idx][lab_idx] = lab
end

# set subplot sizes and local xticks & xrange
sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name

layout[sym_xaxis] = copy(get(layout, :xaxis, attr()))
layout[sym_xaxis][:range] = [extrema(local_x)...]
layout[sym_xaxis][:tickvals] = xticks[path_idx]
layout[sym_xaxis][:ticktext] = xlabs[path_idx]

domain_end = domain_start + rel_xs_lengths[path_idx]
layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end]
domain_start = domain_end + rel_spacing
end
delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts...

return plot(traces, layout; config=config)
end


"""
plot(kpi::KPathInterpolant, bands, [layout]; kwargs...)
Expand Down Expand Up @@ -63,73 +134,33 @@ Alternatively, some simple settings can be set directly via keyword arguments (s
"""
function plot(kpi::KPathInterpolant, bands,
layout::Layout = Layout();
ylims = nothing, ylabel = "Energy", title = nothing,
ylims = default_dispersion_ylims(bands), ylabel = "Energy",
band_highlights::Union{Dict, Nothing} = nothing,
annotations::Union{Dict, Nothing} = nothing,
config::PlotConfig = PlotConfig(responsive=true, displaylogo=false))
kwargs...)
# check input
N = length(kpi)
if !all(band -> length(band) == N, bands)
throw(DimensionMismatch("mismatched dimensions of `kpi` and `bands`"))
end
# merge (and implicitly copy) `layout` (copy ensures we can mutate `layout` without
# corrupting user input)
layout = merge(DEFAULT_PLOTLY_LAYOUT_DISPERSION, layout)

# set default y-limits in layout, if not already set
yaxis = get!(layout.fields, :yaxis, attr())
if isnothing(ylims)
if !haskey(yaxis, :range)
ylims = default_dispersion_ylims(bands)
yaxis[:range] = ylims
else
ylims = yaxis[:range] # grab what was already in `layout`
end
else
# overwrite if ylims was provided, regardless of what it is in `layout`
yaxis[:range] = ylims
end
yaxis[:title] = ylabel

# add title, if requested
if !isnothing(title)
if title isa String
layout[:title] = attr(text=title)
else
layout[:title] = title
end
end

# prepare to plot band diagram
Npaths = length(kpi.kpaths)
local_xs = cumdists.(cartesianize(kpi).kpaths)
local_xs_lengths = last.(local_xs)
total_xs_lengths = sum(local_xs_lengths)
spacing = total_xs_lengths / 30
rel_xs_lengths = local_xs_lengths./(total_xs_lengths+spacing*(Npaths-1))
rel_spacing = spacing/(total_xs_lengths+spacing*(Npaths-1))

# plot bands and k-lines/labels
tbands = Vector{GenericTrace{Dict{Symbol,Any}}}()
xticks = [Vector{Float64}(undef, length(labels)) for labels in kpi.labels]
xlabs = [Vector{Symbol}(undef, length(labels)) for labels in kpi.labels]
start_idx = 1
domain_start = 0.0 # subplot domain "start" point
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
stop_idx = start_idx+length(local_x)-1
# plot bands
for (i, band) in enumerate(bands)
line = something(_get_value_if_in_ranges(band_highlights, i),
attr(color=BAND_COL[], width=3)) # default
line = something(_get_value_if_in_ranges(band_highlights, i),
attr(color=BAND_COL[], width=3)) # default
push!(tbands,
PlotlyJS.scatter(x=local_x, y=band[start_idx:stop_idx],
hoverinfo="y", mode="lines", line=line, xaxis="x$path_idx", yaxis="y"))
end
# define xticks
for (lab_idx, (x_idx, lab)) in enumerate(labels)
xticks[path_idx][lab_idx] = local_x[x_idx]
xlabs[path_idx][lab_idx] = lab
end

# place any high-symmetry point annotations
if annotations !== nothing
for (lab, positions_and_strs) in annotations
Expand Down Expand Up @@ -157,25 +188,12 @@ function plot(kpi::KPathInterpolant, bands,
end
end
end

# set subplot sizes and local xticks & xrange
sym_xaxis = Symbol("xaxis$path_idx") # subplot xaxis name

layout[sym_xaxis] = copy(get(layout, :xaxis, attr()))
layout[sym_xaxis][:range] = [extrema(local_x)...]
layout[sym_xaxis][:tickvals] = xticks[path_idx]
layout[sym_xaxis][:ticktext] = xlabs[path_idx]

domain_end = domain_start + rel_xs_lengths[path_idx]
layout[Symbol(sym_xaxis, "_domain")] = [domain_start, domain_end]
domain_start = domain_end + rel_spacing

# prepare for next iteration
start_idx = stop_idx + 1
end
delete!(layout.fields, :xaxis) # get rid of unused xaxis in layout; causes artifacts...

return plot(tbands, layout; config=config)
return plot(kpi, tbands, layout; ylims=ylims, ylabel=ylabel, kwargs...)
end
# `bands` can also be supplied as a matrix (w/ distinct bands in distinct columns)
function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real},
Expand All @@ -188,7 +206,7 @@ function plot(kpi::KPathInterpolant, bands::AbstractMatrix{<:Real},
end

function default_dispersion_ylims(bands)
ylims = [mapfoldl(minimum, min, bands, init=Inf),
ylims = [mapfoldl(minimum, min, bands, init=Inf),
mapfoldl(maximum, max, bands, init=-Inf)]
δ = (ylims[2]-ylims[1])/30
if isapprox(ylims[1], 0, atol=1e-6)
Expand All @@ -205,4 +223,63 @@ function _get_value_if_in_ranges(d::Dict, i::Integer)
end
return nothing
end
_get_value_if_in_ranges(::Nothing, ::Integer) = nothing
_get_value_if_in_ranges(::Nothing, ::Integer) = nothing

# ---------------------------------------------------------------------------------------- #

"""
plot(kpi::KPathInterpolant, ωs, fields, [layout]; kwargs...)
Plot a dispersion heatmap for provided `fields` and **k**-path interpolant `kpi`.
`fields` must be an iterable of real matrices (e.g., a `Vector{Matrix{Float64}}`),
with the first iteration running over distinct fields to overlay.
Note that the size of each iterant of `fields` must equal `(length(kpi), length(ωs))`.
## Keyword arguments `kwargs`
- `ylims`: y-axis limits (default: `extrema(ωs)`)
- `ylabel`: y-axis label (default: "Energy")
- `title`: plot title (default: `nothing`); can be a `String` or an `attr` dictionary of
PlotlyJS properties
- `opacity`: transparency of colormap (default: `1`); useful to set a value less than one
when overlaying multiple `fields`
- `colorscale`: an iteratable of PlotlyJS color scales (default: ["YlGnBu"])
- `reversescale`: boolean that reverses color scale (default: false)
"""
function plot(kpi::KPathInterpolant, ωs, fields,
layout::Layout = Layout();
ylims = extrema(ωs), ylabel = "Energy",
opacity=1, colorscale=["YlGnBu"], reversescale=false,
kwargs...)
# check input
N = length(kpi); M = length(ωs)
if !all(field -> size(field) == (N,M), fields)
throw(DimensionMismatch("mismatched dimensions of `kpi` with `ωs` and `fields`"))
end

# prepare to plot band diagram
local_xs = cumdists.(cartesianize(kpi).kpaths)

# plot bands and k-lines/labels
heatmaps = Vector{GenericTrace{Dict{Symbol,Any}}}()
start_idx = 1
for (path_idx, (local_x, labels)) in enumerate(zip(local_xs, kpi.labels))
stop_idx = start_idx+length(local_x)-1
# plot fields
for (i, field) in enumerate(fields)
push!(heatmaps,
PlotlyJS.heatmap(x=local_x, y=ωs, z=transpose(field[start_idx:stop_idx,:]),
xaxis="x$path_idx", yaxis="y", opacity=opacity,
colorscale=colorscale[mod(i-1,length(colorscale))+1], reversescale=reversescale))
end
start_idx = stop_idx + 1
end

plot(kpi, heatmaps, layout; ylims=ylims, ylabel=ylabel, kwargs...)
end

0 comments on commit 88eb405

Please sign in to comment.