"""Plotting utilities for stochkin.
Every public function in this module applies the stochkin publication
style (:func:`~stochkin.style.set_publication_style`) before drawing,
so that all output figures are publication-ready by default.
The style mirrors the Matplotlib rcParams used in the FES_2D.ipynb
analysis notebook (Arial font, inward ticks, white background, 300 dpi,
single-column figure width).
"""
from __future__ import annotations
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap, BoundaryNorm, LogNorm, Normalize
from .style import (
publication_style,
LABEL_SIZE,
TICK_SIZE,
LEGEND_SIZE,
CBAR_LABEL_SIZE,
CBAR_TICK_SIZE,
TITLE_SIZE,
)
# ── helper ────────────────────────────────────────────────────────────
def _apply_pub_axes(ax, xlabel=None, ylabel=None, title=None):
"""Apply consistent tick/label formatting to an Axes object."""
ax.tick_params(axis="both", which="major", labelsize=TICK_SIZE,
direction="in", length=5, width=0.8)
ax.tick_params(axis="both", which="minor",
direction="in", length=3, width=0.8)
if xlabel is not None:
ax.set_xlabel(xlabel, size=LABEL_SIZE)
if ylabel is not None:
ax.set_ylabel(ylabel, size=LABEL_SIZE)
if title is not None:
ax.set_title(title, size=TITLE_SIZE)
def _apply_pub_cbar(cbar, label=None):
"""Format a colorbar to match the publication style."""
cbar.ax.tick_params(labelsize=CBAR_TICK_SIZE)
if label is not None:
cbar.set_label(label, fontsize=CBAR_LABEL_SIZE)
# =====================================================================
# plot_results
# =====================================================================
[docs]
def plot_results(times, positions, velocities, energies, bins=50):
"""Basic diagnostic plots for a 2D Langevin trajectory.
Four panels: trajectory, energy vs time, position histogram,
energy histogram.
"""
with publication_style():
# Trajectory
fig, ax = plt.subplots(figsize=(3.3, 2.8))
ax.plot(positions[:, 0], positions[:, 1], "-o", markersize=2)
_apply_pub_axes(ax, "x₁", "x₂", "2D Langevin trajectory")
fig.tight_layout()
# Energy vs time
fig2, ax2 = plt.subplots(figsize=(3.3, 2.8))
ax2.plot(times, energies, "-")
_apply_pub_axes(ax2, "time", "Total Energy", "Energy vs Time")
fig2.tight_layout()
# Position histogram
fig3, ax3 = plt.subplots(figsize=(3.3, 2.8))
h = ax3.hist2d(positions[:, 0], positions[:, 1], bins=bins, cmap="viridis")
_apply_pub_cbar(fig3.colorbar(h[3], ax=ax3), label="counts")
_apply_pub_axes(ax3, "x₁", "x₂", "Position distribution")
fig3.tight_layout()
# Energy histogram
fig4, ax4 = plt.subplots(figsize=(3.3, 2.8))
ax4.hist(energies, bins=bins, alpha=0.7)
_apply_pub_axes(ax4, "Energy", "Frequency", "Energy distribution")
fig4.tight_layout()
plt.show()
# =====================================================================
# plot_mfpt_matrix
# =====================================================================
[docs]
def plot_mfpt_matrix(
mfpt_network_results,
log10=False,
cmap="magma",
figsize=(3.3, 2.8),
title="MFPT matrix τ(i→j)",
):
"""Plot MFPT(i→j) as a heatmap.
Parameters
----------
mfpt_network_results : dict
Output of ``compute_mfpt_network``.
log10 : bool
If True, plot log₁₀(τᵢⱼ).
cmap : str
Colormap.
figsize : tuple
Figure size.
title : str
Plot title.
"""
with publication_style():
tau = np.array(mfpt_network_results["mfpt_matrix"], dtype=float)
n = tau.shape[0]
tau_masked = np.ma.masked_array(tau, mask=False)
for i in range(n):
tau_masked[i, i] = np.ma.masked
if log10:
with np.errstate(divide="ignore", invalid="ignore"):
data = np.log10(tau_masked)
clabel = r"$\log_{10} \tau_{ij}$"
else:
data = tau_masked
clabel = r"$\tau_{ij}$"
fig, ax = plt.subplots(figsize=figsize)
im = ax.imshow(data, origin="lower", cmap=cmap, aspect="equal")
_apply_pub_cbar(fig.colorbar(im, ax=ax), label=clabel)
_apply_pub_axes(ax, "j (target basin)", "i (start basin)", title)
ax.set_xticks(np.arange(n))
ax.set_yticks(np.arange(n))
fig.tight_layout()
plt.show()
# =====================================================================
# plot_fp_solution_vs_boltzmann
# =====================================================================
[docs]
def plot_fp_solution_vs_boltzmann(
fp_result,
beta=1.0,
log=True,
cmap="viridis",
figsize=None,
):
"""Compare FP steady-state with Boltzmann distribution.
Parameters
----------
fp_result : dict
Output of ``solve_fp_steady_state``.
beta : float
Inverse temperature.
log : bool
If True, also show log-ratio panel.
cmap : str
Colormap.
figsize : tuple, optional
Figure size (default auto-sized for 2 or 3 panels).
"""
with publication_style():
xs = fp_result["xs"]
ys = fp_result["ys"]
p_grid = fp_result["p_grid"]
U_grid = fp_result["U_grid"]
X, Y = np.meshgrid(xs, ys, indexing="ij")
boltz = np.exp(-beta * U_grid)
boltz /= np.sum(boltz)
ncols = 3 if log else 2
if figsize is None:
figsize = (3.3 * ncols, 2.8)
fig, axes = plt.subplots(1, ncols, figsize=figsize)
im0 = axes[0].contourf(X, Y, p_grid, levels=40, cmap=cmap)
_apply_pub_cbar(fig.colorbar(im0, ax=axes[0]))
_apply_pub_axes(axes[0], "x", "y", "FP steady p(x)")
im1 = axes[1].contourf(X, Y, boltz, levels=40, cmap=cmap)
_apply_pub_cbar(fig.colorbar(im1, ax=axes[1]))
_apply_pub_axes(axes[1], "x", "y", r"Boltzmann $\propto e^{-\beta U}$")
if log:
with np.errstate(divide="ignore"):
log_ratio = np.log(p_grid + 1e-30) - np.log(boltz + 1e-30)
im2 = axes[2].contourf(X, Y, log_ratio, levels=40, cmap="coolwarm")
_apply_pub_cbar(fig.colorbar(im2, ax=axes[2]))
_apply_pub_axes(axes[2], "x", "y",
r"$\ln p_{\rm FP} - \ln p_{\rm Boltz}$")
fig.tight_layout()
plt.show()
# =====================================================================
# plot_basin_network
# =====================================================================
[docs]
def plot_basin_network(
basin_network,
levels=40,
fes_cmap="viridis",
basin_cmap="tab20",
alpha_basins=0.35,
show_minima=True,
annotate_ids=True,
figsize=(3.3, 2.8),
):
"""Plot the FES with overlaid basin partition and minima.
Parameters
----------
basin_network : BasinNetwork
As returned by ``detect_basins_for_mfpt``.
levels : int
Number of contour levels.
fes_cmap, basin_cmap : str
Colormaps.
alpha_basins : float
Basin overlay transparency.
show_minima, annotate_ids : bool
Show minimum markers / basin-id labels.
figsize : tuple
Figure size.
"""
with publication_style():
xs = basin_network.xs
ys = basin_network.ys
U = basin_network.U
labels = basin_network.labels
X, Y = np.meshgrid(xs, ys, indexing="ij")
Z = np.ma.masked_invalid(U)
fig, ax = plt.subplots(figsize=figsize)
cf = ax.contourf(X, Y, Z, levels=levels, cmap=fes_cmap)
_apply_pub_cbar(fig.colorbar(cf, ax=ax), label="FES (kJ/mol)")
label_mask = np.ma.masked_where(labels < 0, labels)
n_basins = basin_network.n_basins
bounds = np.arange(-1, n_basins + 1, 1)
n_bins = len(bounds) - 1
cmap_b = plt.cm.get_cmap("tab20", n_bins)
norm = BoundaryNorm(bounds, ncolors=n_bins)
ax.imshow(
label_mask.T,
origin="lower",
extent=[xs[0], xs[-1], ys[0], ys[-1]],
cmap=cmap_b,
norm=norm,
alpha=alpha_basins,
aspect="auto",
)
if show_minima:
for b in basin_network.basins:
ax.scatter(
b.minimum[0], b.minimum[1],
s=60, c="k", edgecolors="white", linewidths=1.0, zorder=5,
)
if annotate_ids:
ax.text(
b.minimum[0], b.minimum[1], f"{b.id}",
color="white", fontsize=9, ha="center", va="center",
zorder=6,
)
_apply_pub_axes(ax, "x", "y", "Basins on FES")
fig.tight_layout()
plt.show()
# =====================================================================
# plot_central_well_barrier_ring
# =====================================================================
[docs]
def plot_central_well_barrier_ring(
a=1.0,
b=1.0,
A=0.5,
sigma=0.5,
r_max=2.0,
n_points=400,
grid_size=200,
):
"""Plot the radial profile + 2D landscape of the ring-barrier potential."""
with publication_style():
rs = np.linspace(0, r_max, n_points)
U_radial = b * rs**4 - a * rs**2 - A * np.exp(-rs**2 / sigma**2)
fig, ax = plt.subplots(figsize=(3.3, 2.8))
ax.plot(rs, U_radial)
_apply_pub_axes(ax, "r", "U(r)", "Radial profile")
fig.tight_layout()
x = np.linspace(-r_max, r_max, grid_size)
y = np.linspace(-r_max, r_max, grid_size)
X, Y = np.meshgrid(x, y)
R2 = X**2 + Y**2
U = b * R2**2 - a * R2 - A * np.exp(-R2 / sigma**2)
fig2, ax2 = plt.subplots(figsize=(3.3, 2.8))
cp = ax2.contourf(X, Y, U, levels=50, cmap="viridis")
_apply_pub_cbar(fig2.colorbar(cp, ax=ax2), label="U(x)")
_apply_pub_axes(ax2, "x₁", "x₂", "2D potential")
ax2.set_aspect("equal")
fig2.tight_layout()
plt.show()
# =====================================================================
# plot_2d_fes – PLUMED 2D FES contour (like notebook plot_2d_contourf_MAX)
# =====================================================================
[docs]
def plot_2d_fes(
data_path,
*,
save_path=None,
levels=10,
fes_max=None,
delta=90,
reweight=False,
invert=False,
xlim=None,
ylim=None,
auto_zoom=True,
zoom_pad=0.08,
cmap_name="rainbow",
swap_xy=True,
show_cbar=True,
cbar_label="FES (kJ/mol)",
xlabel=None,
ylabel=None,
pathways=None,
pathway_style=None,
pathway_markers=False,
pathway_every=1,
pathway_labels=None,
figsize=(4, 3),
ax=None,
):
"""Plot a PLUMED-format 2D FES as a filled contour.
This is a cleaned-up version of the ``plot_2d_contourf_MAX`` helper
from the FES_2D.ipynb notebook, with publication styling applied
automatically.
Parameters
----------
data_path : str or Path
Path to a PLUMED ``sum_hills`` 2D FES file.
save_path : str or Path, optional
If given, save the figure to this path at 300 dpi.
levels : int
Number of contour levels.
fes_max : float, optional
Fixed colour-scale maximum (kJ/mol). Also used as the masking
threshold if *delta* is not explicitly set.
delta : float
Percentile threshold for masking (default 90).
reweight : bool
Subtract the minimum from the FES.
swap_xy : bool
Swap x/y axes (matches the notebook convention CN(Cl) vs CN(O)).
pathways : list, optional
MFEP overlay data (arrays or file paths).
ax : matplotlib Axes, optional
Draw onto an existing Axes instead of creating a new figure.
Returns
-------
fig, ax : if *ax* was None
ax : if an existing *ax* was passed
"""
with publication_style():
data = np.genfromtxt(data_path, comments="#")
data = data[np.all(np.isfinite(data[:, :3]), axis=1)]
x, y, z = data[:, 0], data[:, 1], data[:, 2].copy()
if reweight:
z -= np.min(z)
x_unique, y_unique = np.unique(x), np.unique(y)
nx, ny = len(x_unique), len(y_unique)
Z = z.reshape(ny, nx)
if swap_xy:
X, Y = np.meshgrid(y_unique, x_unique)
Zplot, x_axis, y_axis = Z.T, y_unique, x_unique
_xlabel = xlabel or "CN(Cl)"
_ylabel = ylabel or "CN(O)"
else:
X, Y = np.meshgrid(x_unique, y_unique)
Zplot, x_axis, y_axis = Z, x_unique, y_unique
_xlabel = xlabel or "CV₁"
_ylabel = ylabel or "CV₂"
thr = float(fes_max) if fes_max is not None else float(np.percentile(Zplot, delta))
Zm = np.ma.masked_where(Zplot >= thr, Zplot)
cmap = plt.get_cmap(cmap_name).copy()
cmap.set_bad("white")
vmax = float(fes_max) if fes_max is not None else (
float(np.ma.max(Zm)) if np.ma.count(Zm) > 0 else float(np.max(Zplot))
)
lev = np.linspace(0.0, vmax, int(levels))
norm = Normalize(vmin=0.0, vmax=vmax)
own_fig = ax is None
if own_fig:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
cf = ax.contourf(X, Y, Zm, levels=lev, cmap=cmap, norm=norm, extend="max")
# Pathways overlay
if pathways is not None:
if pathway_style is None:
pathway_style = dict(lw=1.0, alpha=1.0)
if not isinstance(pathways, (list, tuple)):
pathways = [pathways]
for k, p in enumerate(pathways):
P = np.genfromtxt(p, comments="#") if isinstance(p, str) else np.asarray(p)
px, py = P[:, 1], P[:, 2]
xplot, yplot = (py, px) if swap_xy else (px, py)
ax.plot(xplot, yplot, color="k", **pathway_style)
if pathway_markers:
ax.plot(xplot[::pathway_every], yplot[::pathway_every],
ls="none", marker="o", ms=3, color="k",
alpha=pathway_style.get("alpha", 1.0))
if show_cbar:
cbar = fig.colorbar(cf, ax=ax)
_apply_pub_cbar(cbar, label=cbar_label)
# Auto zoom
if auto_zoom and xlim is None and ylim is None and np.ma.count(Zm) > 0:
jj, ii = np.where(~Zm.mask)
xmin, xmax = x_axis[ii.min()], x_axis[ii.max()]
ymin, ymax = y_axis[jj.min()], y_axis[jj.max()]
dx = (xmax - xmin) or 1.0
dy = (ymax - ymin) or 1.0
ax.set_xlim(xmin - zoom_pad * dx, xmax + zoom_pad * dx)
ax.set_ylim(ymin - zoom_pad * dy, ymax + zoom_pad * dy)
if xlim is not None:
ax.set_xlim(*xlim)
if ylim is not None:
ax.set_ylim(*ylim)
if invert:
ax.invert_xaxis()
_apply_pub_axes(ax, _xlabel, _ylabel)
fig.tight_layout()
if save_path:
fig.savefig(save_path, dpi=300)
if own_fig:
return fig, ax
return ax
# =====================================================================
# draw_barrier_arrows – ported from FES_2D.ipynb
# =====================================================================
[docs]
def draw_barrier_arrows(
ax,
x,
y_top,
y0,
*,
y_bottom=None,
label=True,
label_fmt="{:.1f}",
label_side="right",
label_dx=0.05,
label_dy=0.0,
label_x_overrides=None,
label_y_overrides=None,
y_override_mode="abs",
arrowprops=None,
text_kwargs=None,
):
"""Draw double-headed barrier-height arrows on an Axes.
Ported from the ``draw_barrier_arrows`` helper in FES_2D.ipynb.
Parameters
----------
ax : matplotlib Axes
x : float or array
Horizontal position(s) of the arrow(s).
y_top : float or array
Top of each arrow (barrier peak FES value).
y0 : float
Default baseline (bottom) for the arrows.
y_bottom : float or array, optional
Per-arrow bottom override.
label : bool
Annotate the barrier height ΔF.
label_side : ``'right'`` or ``'left'``
Side on which to place the label text.
"""
x = np.atleast_1d(x)
y_top = np.atleast_1d(y_top) if np.ndim(y_top) else np.full_like(x, y_top, dtype=float)
y_bottom = (np.atleast_1d(y_bottom) if np.ndim(y_bottom) else np.full_like(x, y_bottom, dtype=float)) if y_bottom is not None else np.full_like(x, y0, dtype=float)
base_ap = dict(arrowstyle="<->", ls="--", lw=0.75, color="0.2")
if arrowprops:
base_ap.update(arrowprops)
side = label_side.lower()
default_dx = abs(label_dx) if side == "right" else -abs(label_dx)
default_ha = "left" if side == "right" else "right"
base_tk = dict(ha=default_ha, va="center", fontsize=LEGEND_SIZE)
if text_kwargs:
base_tk.update(text_kwargs)
label_x_overrides = label_x_overrides or {}
label_y_overrides = label_y_overrides or {}
out = []
for i, (xi, yb, yt) in enumerate(zip(x, y_bottom, y_top)):
ann = ax.annotate("", xy=(xi, yb), xytext=(xi, yt), arrowprops=base_ap)
txt = None
if label:
dF = yt - yb
mid_y = 0.5 * (yb + yt) + label_dy
dx = label_x_overrides.get(i, default_dx)
if i in label_y_overrides:
mid_y = mid_y * label_y_overrides[i] if y_override_mode == "mul" else label_y_overrides[i]
txt = ax.text(xi + dx, mid_y, label_fmt.format(dF), **base_tk)
out.append((ann, txt))
return out