# Copyright 2025 Onera
# This file is part of the Noda package
# SPDX-License-Identifier: GPL-3.0-or-later
"""Generate plots."""
from copy import deepcopy
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import matplotlib as mpl
# Module-wide matplotlib defaults applied on import: scientific-notation
# thresholds for tick labels, and a light grid drawn on every axes.
mpl.rcParams.update({
    'axes.formatter.limits': (-2, 3),
    'axes.grid': True,
    'grid.alpha': 0.5,
})
[docs]
def process_xaxis_unit(zunit):
    """
    Convert x-axis unit label to plotting-ready form and make multiplier.

    Parameters
    ----------
    zunit : str
        Unit in which distance is to be plotted on x axis. Recognized values
        are 'm', 'mm', 'um' and 'nm'.

    Returns
    -------
    zunit : str
        Unit label ready for matplotlib text. 'um' is turned into a LaTeX
        micro sign followed by 'm'; other units are returned unchanged.
    zmult : float
        Factor by which distance (x axis, stored in m) will be multiplied
        before plotting.

    Raises
    ------
    ValueError
        If `zunit` is not a recognized unit. Previously an unknown unit
        silently produced a None multiplier, which failed later on
        multiplication.
    """
    multipliers = {'m': 1, 'mm': 1e3, 'um': 1e6, 'nm': 1e9}
    if zunit not in multipliers:
        raise ValueError(f'Unrecognized x-axis unit {zunit!r}; expected one '
                         f'of {sorted(multipliers)}')
    zmult = multipliers[zunit]
    if zunit == 'um':
        zunit = r'$\mu$m'
    return zunit, zmult
[docs]
def process_yaxis_unit(varname):
    """
    Pre-process info to be used on y axis.

    Convert y-axis variable name to plotting-ready form and make y-axis unit
    label and multiplier.

    Parameters
    ----------
    varname : str
        Name of variable to be plotted on y axis.

    Returns
    -------
    varname : str
        Name of variable as LaTeX math text (without the enclosing '$').
        Names starting with 'J' become J with the rest of the name as a
        roman superscript; 'mu' and 'deformation' become Greek symbols;
        any other name is returned unchanged.
    varunit : str
        Unit label of y-axis variable ('' when the variable is not known).
    varmult : float
        Factor by which y-axis variable will be multiplied before plotting
        (1e-3 for 'mu', so that it is displayed in kJ/mol; 1 otherwise).
    """
    varmult = 1
    varunit = ''
    if varname.startswith('J'):
        varunit = 'mol m$^{-2}$ s$^{-1}$'
        varname = rf'J^\mathrm{{{varname[1:]}}}'
    elif varname == 'mu':
        varunit = 'kJ/mol'
        varname = r'\mu'
        # Values are scaled by 1e-3 to match the 'kJ/mol' label.
        varmult = 1e-3
    elif varname == 'deformation':
        varunit = '-'
        varname = r'\varepsilon'
    elif varname == 'v':
        varunit = 'm/s'
    elif varname == 'c':
        varunit = 'mol/m3'
    return varname, varunit, varmult
[docs]
def plot_profile_single(z, var, varname, title, zunit='um'):
    """
    Plot a variable against distance in a new figure.

    A 1D array gives a single unlabelled curve. A dict gives one labelled
    curve per key, with the 'Va' entry drawn in black, plus a legend.

    Parameters
    ----------
    z : 1D array
        Positions used on the x axis: node positions (size `nz`) or midpoint
        positions (size `nz` - 1), depending on the y-axis variable.
    var : 1D array or dict of 1D arrays
        Quantity plotted on the y axis.
    varname : str
        Name of y-axis variable (see :func:`process_yaxis_unit`).
    title : str
        Plot title.
    zunit : str, optional
        x-axis unit. The default is 'um'.

    Returns
    -------
    fig, ax : matplotlib figure and axis
    """
    xunit_label, zscale = process_xaxis_unit(zunit)
    yname, yunit, yscale = process_yaxis_unit(varname)
    ylabel = None
    fig, ax = plt.subplots()
    if isinstance(var, np.ndarray):
        ax.plot(z*zscale, var*yscale)
        ylabel = f'${yname}$ ({yunit})'
    elif isinstance(var, dict):
        for key, values in var.items():
            ax.plot(z*zscale, values*yscale,
                    c=('k' if key == 'Va' else None), label=key)
        ylabel = f'${yname}_i$ ({yunit})'
        ax.legend()
    ax.set_xlabel(f'$z$ ({xunit_label})')
    ax.set_ylabel(ylabel)
    ax.ticklabel_format(style='sci', scilimits=(-2, 4))
    ax.set_title(title)
    return fig, ax
[docs]
def add_profile(z, var, varname, zunit='um', extra_legend=None):
    """
    Overlay a set of profiles on the current axes.

    Every added profile is drawn as a black dashed line. When `extra_legend`
    is given, one empty line with that same style is added under that
    label, so a single legend entry describes all of the added profiles.

    Parameters
    ----------
    z : 1D array
        Positions used on the x axis: node positions (size `nz`) or midpoint
        positions (size `nz` - 1), depending on the y-axis variable.
    var : dict of 1D arrays
        Quantities added to the y axis.
    varname : str
        Name of y-axis variable (only its plotting multiplier is used).
    zunit : str, optional
        x-axis unit. The default is 'um'.
    extra_legend : str, optional
        Common legend to all added profiles. The default is None.
    """
    _, zscale = process_xaxis_unit(zunit)
    _, _, yscale = process_yaxis_unit(varname)
    ax = plt.gca()
    # Reset the axes property cycle (the added curves are forced to black
    # anyway; presumably this makes later plots restart the default cycle).
    ax.set_prop_cycle(None)
    for profile in var.values():
        ax.plot(z*zscale, profile*yscale, '--k', mfc='none')
    if not extra_legend:
        return
    # Proxy artist carrying the shared legend label.
    ax.plot([], [], '--k', mfc='none', label=extra_legend)
    ax.legend()
[docs]
def plot_profile_quartet(result, title, suptitle,
                         zunit='um', ylim=None, exclude_dep=True):
    """
    Plot composition and flux profiles.

    Produce a 2 x 2 grid of subplots with the following variables:

    * Atom fraction (top left)
    * Relative difference between simulated and equilibrium vacancy site
      fraction (bottom left)
    * Flux in the lattice-fixed frame (top right)
    * Volume fraction of pores, in percent (bottom right).

    Parameters
    ----------
    result : :class:`results.UnitResult`
        Simulation results at given time step. Attributes read: `z`, `x`,
        `Jlat`, `ryVa`, `fp`.
    title : str
        Title of the top-left axis.
    suptitle : str
        Figure title.
    zunit : str, optional
        x-axis unit. The default is 'um'.
    ylim : dict, optional
        Lower and upper limits of the y-axis of the four subplots. Four keys
        are recognized: 'x' (atom fractions), 'y0' (relative vacancy site
        fraction), 'J' (fluxes) and 'fp' (pore fraction, in percent as
        plotted). The values are to be given as a tuple or list of two
        floats. Missing keys or falsy values leave that subplot's limits
        unchanged. The dict is not modified. The default is None.
    exclude_dep : bool, optional
        Exclude the dependent constituent (the last key of `result.x`) from
        the atom fraction subplot. The default is True.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure.
    axes : list of matplotlib.axes.Axes
        Axes, in the order atom fraction, vacancy, flux, pore fraction.
    lines : list of lists of matplotlib.lines.Line2D
        Lines, grouped per axis in the same order as `axes`.
    """
    zunit, zmult = process_xaxis_unit(zunit)
    z = result.z*zmult
    x = result.x
    Jlat = result.Jlat
    ryVa = result.ryVa
    fig = plt.figure(figsize=(14, 8), constrained_layout=True)
    gs = fig.add_gridspec(2, 2, height_ratios=[2, 1])
    ax1 = fig.add_subplot(gs[0, 0])
    ax2 = fig.add_subplot(gs[1, 0])
    ax3 = fig.add_subplot(gs[0, 1])
    ax4 = fig.add_subplot(gs[1, 1])
    lines1, lines2, lines3, lines4 = [], [], [], []
    fig.suptitle(suptitle)

    # Atom fractions (top left).
    ax1.set_title(title, loc='left')
    x_plot_list = list(x)[:-1] if exclude_dep else list(x)
    for k in x_plot_list:
        line, = ax1.plot(z, x[k], label=k)
        lines1.append(line)
    ax1.legend(loc='upper right')
    ax1.set_xticklabels([])
    ax1.set_ylabel('$x_k$')

    # Relative vacancy site fraction deviation (bottom left).
    line, = ax2.plot(z, ryVa, c='k')
    lines2.append(line)
    ax2.set_xlabel(f'$z$ ({zunit})')
    ylabel = r'$(y_\mathrm{Va} - y_\mathrm{Va}^\mathrm{eq})'
    ylabel += r' / y_\mathrm{Va}^\mathrm{eq}$'
    ax2.set_ylabel(ylabel)
    ax2.yaxis.set_label_coords(-0.1, 0.5)

    # Lattice-frame fluxes (top right); vacancy flux first, in black.
    line, = ax3.plot(z, Jlat['Va'], c='k', label='Va')
    lines3.append(line)
    for k in Jlat:
        if k == 'Va':
            continue
        line, = ax3.plot(z, Jlat[k], label=k)
        lines3.append(line)
    ax3.set_ylabel('$J_k^{lat}$ (mol m$^{-2}$ s$^{-1}$)')
    ax3.legend(loc='upper right')
    ax3.set_xticklabels([])
    ax3.ticklabel_format(axis='y', style='sci', scilimits=(-2, 4))

    # Pore volume fraction in percent (bottom right).
    line, = ax4.plot(z, result.fp*100, 'k')
    lines4.append(line)
    ax4.set_xlabel(f'$z$ ({zunit})')
    ax4.set_ylabel('$f_p$ (%)')

    axes = [ax1, ax2, ax3, ax4]
    lines = [lines1, lines2, lines3, lines4]
    if ylim is not None:
        # Read-only lookup: the caller's dict must not gain new keys.
        for ax, key in zip(axes, ['x', 'y0', 'J', 'fp']):
            limits = ylim.get(key)
            if limits:
                ax.set_ylim(limits)
    return fig, axes, lines
[docs]
class StaticProfile:
    """
    Static plot of composition and/or flux profile.

    Attributes
    ----------
    res : :class:`results.UnitResult`
        Simulation results at one time step.
    step : int
        Time step.
    th : float
        Time (h).
    title : str
        Default plot title.
    """

    def __init__(self, res, step, th, title):
        """
        Class constructor.

        Parameters
        ----------
        res : :class:`results.UnitResult`
            Simulation results at one time step.
        step : int
            Time step.
        th : float
            Time (h).
        title : str
            Default plot title.
        """
        self.res = res
        self.step = step
        self.th = th
        self.title = title

    def single(self, varname='x', title=None, plot_dep=False, **kwargs):
        """
        Plot profile at given time step.

        Call :func:`plots.plot_profile_single`. The plotted variable can be
        multivalued (dict) or univalued (1D array).

        Parameters
        ----------
        varname : str, optional
            Name of the attribute of `res` to be plotted. The default is
            `x`.
        title : str, optional
            Plot title. The default is the instance title followed by a line
            giving the step and the time in h.
        plot_dep : bool, optional
            Whether to include the dependent constituent (last key) in `x`
            and `y` plots. The default is False.
        kwargs
            Optional arguments passed to :func:`plot_profile_single`.

        Returns
        -------
        fig, ax : matplotlib figure and axis

        Raises
        ------
        AttributeError
            If `res` has no attribute named `varname`.
        """
        var = getattr(self.res, varname)
        if not plot_dep and varname in ('x', 'y'):
            # The dependent constituent is stored as the last key.
            var = {k: var[k] for k in list(var)[:-1]}
        if title is None:
            title = self.title + f'\nstep {self.step:3}, {self.th:3.1f} h'
        fig, ax = plot_profile_single(self.res.z, var, varname, title,
                                      **kwargs)
        return fig, ax

    def quartet(self, title=None, **kwargs):
        """
        Plot composition and flux profiles at given time step.

        Call :func:`plots.plot_profile_quartet`. Plot x, yVa, Jlat, fp. The
        instance title is used as the figure title.

        Parameters
        ----------
        title : str, optional
            Title of the top-left axis. The default gives the step and the
            time in h.
        kwargs
            Optional arguments passed to :func:`plot_profile_quartet`.

        Returns
        -------
        fig, axes, lines : matplotlib figure, axes and lines
        """
        if title is None:
            title = f'step {self.step:3}, {self.th:3.1f} h'
        suptitle = self.title
        fig, axes, lines = plot_profile_quartet(self.res, title, suptitle,
                                                **kwargs)
        return fig, axes, lines
[docs]
def calculate_view_limits(results, varname):
    """
    Compute y-axis limits adapted to results at all time steps.

    The limits enclose every value of the variable over all time steps,
    after scaling by the multiplier from :func:`process_yaxis_unit`. The
    total span is widened by 10 % around its midpoint. If the span is zero,
    a span of 1 centred on the data is used instead, to avoid a matplotlib
    warning about equal bottom and top limits.

    Parameters
    ----------
    results : dict
        Simulation results keyed by time step. Each value must have an
        attribute named `varname` that is a 1D array or a dict of 1D arrays.
    varname : str
        Name of variable of interest.

    Returns
    -------
    ymin, ymax : float
        y-axis view limits for variable of interest.
    """
    vals = [getattr(r, varname) for r in results.values()]
    _, _, varmult = process_yaxis_unit(varname)
    # Collect every array to scan. Dict entries are reduced one by one
    # rather than stacked, so entries of different lengths are supported.
    arrays = []
    if vals and isinstance(vals[0], np.ndarray):
        arrays = vals
    elif vals and isinstance(vals[0], dict):
        arrays = [a for v in vals for a in v.values()]
    ymin = min((np.min(a) for a in arrays), default=0)*varmult
    ymax = max((np.max(a) for a in arrays), default=0)*varmult
    ymid = (ymin + ymax)/2
    new_span = (ymax - ymin)*1.1
    # set arbitrary span to avoid matplotlib warning if bottom == top
    if new_span == 0:
        new_span = 1
    return ymid - new_span/2, ymid + new_span/2
[docs]
class InteractivePlot():
    """
    Interactive plot of simulation results.

    A slider, or the mouse wheel over the figure, selects which saved time
    step is displayed. Redraws use blitting: the background is captured once
    after the first draw, and only the data lines and the axis title are
    redrawn afterwards.
    """

    def __init__(self, varname, comps, results, saved_times, saved_steps,
                 title):
        """
        Class constructor.

        Parameters
        ----------
        varname : str
            Name of variable to be plotted (attribute of each result).
        comps : list of str
            System constituents. For `varname` 'x', only ``comps[1:-1]`` are
            plotted: the last (dependent) constituent and the first one
            (presumably 'Va' -- TODO confirm) are dropped.
        results : dict of :class:`results.UnitResult`
            Simulation results keyed by time step. A deep copy is stored, so
            the caller's objects are not modified.
        saved_times : 1D array
            Times in h (rounded) that correspond to the saved steps.
        saved_steps : list
            Steps for which simulation results are stored in `results`.
        title : str
            Plot title.
        """
        self.comps = comps
        self.saved_times = saved_times
        self.saved_steps = saved_steps
        self.num_out = len(saved_steps)
        self.i = 0
        self.results = deepcopy(results)
        self.varname = varname
        # Remove dependent component for x profile
        if varname == 'x':
            for r in self.results.values():
                r.x = {k: r.x[k] for k in comps[1:-1]}
        # Use the first saved step rather than assuming a step 0 key exists.
        first = self.results[saved_steps[0]]
        # Positions are converted from m to um for display.
        xlim = [first.z[0]*1e6, first.z[-1]*1e6]
        ylim = calculate_view_limits(self.results, varname)
        self.fig, self.ax, self.lines = self.setup_plot(title, xlim, ylim)
        slider_ax = plt.axes([0.4, 0.895, 0.4, 0.02],
                             facecolor='lightgoldenrodyellow')
        self.slider = Slider(slider_ax, '', 0, self.num_out - 1,
                             valinit=0, valstep=1)
        self.slider.valtext.set_visible(False)
        self.fig.canvas.draw()
        # Background (axes, labels, empty lines) reused by every blit.
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)
        self.update()
        self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
        self.slider.on_changed(self.update_slider)

    def setup_plot(self, title, xlim, ylim):
        """
        Make figure with empty lines.

        Also stores the y-axis multiplier in `self.varmult`.

        Parameters
        ----------
        title : str
            Figure title.
        xlim : list of floats
            x-axis view limits.
        ylim : tuple of floats
            y-axis view limits (bottom, top).

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure.
        ax : matplotlib.axes.Axes
            Axis.
        lines : list of matplotlib.lines.Line2D
            Lines.
        """
        fig, ax = plt.subplots(figsize=(7, 8/1.5), constrained_layout=True)
        fig.suptitle(title)
        # Placeholder, replaced in update() by the step/time title;
        # presumably so the layout reserves room for it.
        ax.set_title(' ', loc='left')
        lines = []
        var = getattr(self.results[self.saved_steps[0]], self.varname)
        varname, varunit, varmult = process_yaxis_unit(self.varname)
        self.varmult = varmult
        ylabel = None
        if isinstance(var, np.ndarray):
            line, = ax.plot([])
            ylabel = f'${varname}$ ({varunit})'
            lines = [line]
        elif isinstance(var, dict):
            for k in var:
                color = 'k' if k == 'Va' else None
                line, = ax.plot([], c=color, label=k)
                lines.append(line)
            ylabel = f'${varname}_i$ ({varunit})'
            # Only labelled (multivalued) plots get a legend.
            ax.legend(loc='upper right')
        ax.set_xlim(xlim)
        ax.set_ylim(ylim)
        ax.set_xlabel(r'$z$ ($\mu$m)')
        ax.set_ylabel(ylabel)
        ax.ticklabel_format(style='sci', scilimits=(-2, 4))
        return fig, ax, lines

    def update_slider(self, val):
        """Set slider to given position (time step) and update the figure."""
        self.i = int(val)
        self.update()

    def on_scroll(self, event):
        """
        Move the slider by one saved step per mouse-wheel notch.

        Scrolling up advances, scrolling down goes back; the index is
        clamped to the valid range. The slider's on_changed callback
        (:meth:`update_slider`) then redraws the figure.
        """
        increment = 1 if event.button == 'up' else -1
        new_i = min(max(self.i + increment, 0), self.num_out - 1)
        self.slider.set_val(new_i)

    def update(self):
        """Update and draw the figure."""
        th = self.saved_times[self.i]
        n = self.saved_steps[self.i]
        res = self.results[n]
        var = getattr(res, self.varname)
        z = res.z*1e6
        title = self.ax.set_title(f'step {n:3}, {th:3.1f} h', loc='left')
        if isinstance(var, np.ndarray):
            self.lines[0].set_data(z, var*self.varmult)
        elif isinstance(var, dict):
            for k, line in zip(var, self.lines):
                line.set_data(z, var[k]*self.varmult)
        # Blit: restore the saved background, draw only changing artists.
        self.fig.canvas.restore_region(self.bg)
        for line in self.lines:
            self.ax.draw_artist(line)
        self.ax.draw_artist(title)
        self.fig.canvas.blit(self.fig.bbox)
[docs]
class InteractivePlotQuartet():
    """
    Interactive plot of composition and flux profiles.

    The layout and quantities plotted are the same as in
    :func:`plot_profile_quartet`. A slider, or the mouse wheel over the
    figure, selects the displayed time step; redraws use blitting.
    """

    def __init__(self, comps, results, saved_times, saved_steps, title):
        """
        Class constructor.

        Parameters
        ----------
        comps : list of str
            System constituents. Atom fractions are plotted for
            ``comps[1:-1]`` (the last, dependent constituent and the first
            one are dropped). Fluxes are plotted for 'Va' and ``comps[1:]``,
            so the first constituent is presumably 'Va' -- TODO confirm.
        results : dict of :class:`results.UnitResult`
            Simulation results keyed by time step. A deep copy is stored, so
            the caller's objects are not modified.
        saved_times : 1D array
            Times in h (rounded) that correspond to the saved steps.
        saved_steps : list
            Steps for which simulation results are stored in `results`.
        title : str
            Plot title.
        """
        self.comps = comps
        self.results = deepcopy(results)
        self.saved_times = saved_times
        self.saved_steps = saved_steps
        self.num_out = len(saved_steps)
        self.i = 0
        # Remove dependent component for x profile
        for r in self.results.values():
            r.x = {k: r.x[k] for k in comps[1:-1]}
        # Use the first saved step rather than assuming a step 0 key exists.
        first = self.results[saved_steps[0]]
        # Positions are converted from m to um for display.
        xlim = [first.z[0]*1e6, first.z[-1]*1e6]
        ylim = {varname: calculate_view_limits(self.results, varname)
                for varname in ['x', 'ryVa', 'fp', 'Jlat']}
        self.fig, self.axes, self.lines = self.setup_plot(title,
                                                          xlim, ylim)
        self.ax1, self.ax2, self.ax3, self.ax4 = self.axes
        self.lines1, self.lines2, self.lines3, self.lines4 = self.lines
        slider_ax = plt.axes([0.25, 0.95, 0.5, 0.02],
                             facecolor='lightgoldenrodyellow')
        self.slider = Slider(slider_ax, '', 0, self.num_out - 1,
                             valinit=0, valstep=1)
        self.slider.valtext.set_visible(False)
        self.fig.canvas.draw()
        # Background (axes, labels, empty lines) reused by every blit.
        self.bg = self.fig.canvas.copy_from_bbox(self.fig.bbox)
        self.update()
        self.fig.canvas.mpl_connect('scroll_event', self.on_scroll)
        self.slider.on_changed(self.update_slider)

    def _flux_keys(self):
        """Return Jlat keys in the order of the flux lines ('Va' first)."""
        return ['Va'] + list(self.comps[1:])

    def setup_plot(self, title, xlim, ylim):
        """
        Make figure with empty lines.

        Parameters
        ----------
        title : str
            Figure title.
        xlim : list of floats
            x-axis view limits.
        ylim : dict of tuples
            y-axis view limits, keyed by 'x', 'ryVa', 'Jlat' and 'fp'. The
            'fp' limits are given as fractions and converted to percent.

        Returns
        -------
        fig : matplotlib.figure.Figure
            Figure.
        axes : list of matplotlib.axes.Axes
            Axes.
        lines : list of lists of matplotlib.lines.Line2D
            Lines.
        """
        fig = plt.figure(figsize=(14, 8), constrained_layout=True)
        gs = fig.add_gridspec(2, 2, height_ratios=[2, 1])
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[1, 0])
        ax3 = fig.add_subplot(gs[0, 1])
        ax4 = fig.add_subplot(gs[1, 1])
        fig.suptitle(title)
        # Placeholder, replaced in update() by the step/time title.
        ax1.set_title(' ', loc='left')
        lines1, lines2, lines3, lines4 = [], [], [], []

        # Atom fractions (top left).
        for k in self.comps[1:-1]:
            line, = ax1.plot([], [], label=k)
            lines1.append(line)
        ax1.set_xlim(xlim)
        ax1.set_ylim(ylim['x'])
        ax1.legend(loc='upper right')
        ax1.set_xticklabels([])
        ax1.set_ylabel('$x_k$')

        # Relative vacancy site fraction deviation (bottom left).
        line, = ax2.plot([], [], c='k')
        lines2.append(line)
        ax2.set_xlim(xlim)
        ax2.set_ylim(ylim['ryVa'])
        ax2.set_xlabel(r'$z$ ($\mu$m)')
        ylabel = r'$(y_\mathrm{Va} - y_\mathrm{Va}^\mathrm{eq})'
        ylabel += r' / y_\mathrm{Va}^\mathrm{eq}$'
        ax2.set_ylabel(ylabel)
        ax2.yaxis.set_label_coords(-0.1, 0.5)

        # Lattice-frame fluxes (top right); vacancy flux first, in black.
        for k in self._flux_keys():
            if k == 'Va':
                line, = ax3.plot([], [], 'k', label='Va')
            else:
                line, = ax3.plot([], [], label=k)
            lines3.append(line)
        ax3.set_xlim(xlim)
        ax3.set_ylim(ylim['Jlat'])
        ax3.set_ylabel('$J_k^{lat}$ (mol m$^{-2}$ s$^{-1}$)')
        ax3.legend(loc='upper right')
        ax3.set_xticklabels([])
        ax3.ticklabel_format(axis='y', style='sci', scilimits=(-2, 4))

        # Pore volume fraction in percent (bottom right).
        line, = ax4.plot([], [], 'k')
        lines4.append(line)
        ax4.set_xlim(xlim)
        ax4.set_ylim([yl*100 for yl in ylim['fp']])
        ax4.set_xlabel(r'$z$ ($\mu$m)')
        ax4.set_ylabel('$f_p$ (%)')

        axes = [ax1, ax2, ax3, ax4]
        lines = [lines1, lines2, lines3, lines4]
        return fig, axes, lines

    def update_slider(self, val):
        """Set slider to given position (time step) and update the figure."""
        self.i = int(val)
        self.update()

    def on_scroll(self, event):
        """
        Move the slider by one saved step per mouse-wheel notch.

        Scrolling up advances, scrolling down goes back; the index is
        clamped to the valid range. The slider's on_changed callback
        (:meth:`update_slider`) then redraws the figure.
        """
        increment = 1 if event.button == 'up' else -1
        new_i = min(max(self.i + increment, 0), self.num_out - 1)
        self.slider.set_val(new_i)

    def update(self):
        """Update and draw the figure."""
        th = self.saved_times[self.i]
        n = self.saved_steps[self.i]
        res = self.results[n]
        z = res.z*1e6
        x = res.x
        Jlat = res.Jlat
        title = self.ax1.set_title(f'step {n:3}, {th:3.1f} h', loc='left')
        # Same constituent lists as in setup_plot, so data match labels.
        for k, line in zip(self.comps[1:-1], self.lines1):
            line.set_data(z, x[k])
        self.lines2[0].set_data(z, res.ryVa)
        # Look fluxes up by key instead of relying on Jlat's iteration order.
        for k, line in zip(self._flux_keys(), self.lines3):
            line.set_data(z, Jlat[k])
        self.lines4[0].set_data(z, res.fp*100)
        # Blit: restore the saved background, draw only changing artists.
        self.fig.canvas.restore_region(self.bg)
        for ax, lines in zip(self.axes, self.lines):
            for line in lines:
                ax.draw_artist(line)
        self.ax1.draw_artist(title)
        self.fig.canvas.blit(self.fig.bbox)
[docs]
def plot_mass_balance(results):
    """
    Plot the relative change in amounts over time.

    One curve per constituent except the first, using entry ``i`` of each
    step's `rdIN`. One more curve uses `rdINv`, labelled
    'Va + pore + void'.

    Parameters
    ----------
    results : :class:`results.SimulationResults`
        Simulation results. Attributes read: `comps`, `steps`,
        `saved_steps`, `saved_times`.

    Returns
    -------
    fig, ax : matplotlib figure and axis
    """
    fig, ax = plt.subplots()
    times = results.saved_times
    saved = [results.steps[n] for n in results.saved_steps]
    for idx, comp in enumerate(results.comps[1:], start=1):
        ax.plot(times, [step.rdIN[idx] for step in saved], label=comp)
    ax.plot(times, [step.rdINv for step in saved], label='Va + pore + void')
    ax.legend()
    ax.set_xlabel('$t$ (h)')
    ax.set_ylabel('$(N_t - N_0)/N_0$')
    return fig, ax