Source code for tfep.utils.plumed.plot

#!/usr/bin/env python


# =============================================================================
# MODULE DOCSTRING
# =============================================================================

"""
Provide utility functions to plot data read from output files generated by PLUMED.
"""


# =============================================================================
# GLOBAL IMPORTS
# =============================================================================

from matplotlib import pyplot as plt


# =============================================================================
# UTILITY FUNCTIONS
# =============================================================================

def plot_trajectory(
        data,
        col_names=None,
        time_unit=None,
        stride=1,
        axes=None,
        plot_kwargs=None,
        *,
        data_time_unit='fs',
):
    """Plot the trajectory of all the given columns in time.

    Parameters
    ----------
    data : Dict[str, numpy.ndarray]
        The named columns of the PLUMED output file. For the expected format,
        see ``plumedwrapper.io.read_columns``. There must be a 'time' column
        in the data.
    col_names : str or List[str], optional
        A single name or a list (or any iterable) of the column names to plot.
        If not given, all columns except 'time' are plotted.
    time_unit : str, optional
        The unit of time (e.g., 'ps', or 'ns') used for the x axis. If not
        given, or equal to ``data_time_unit``, the 'time' column is plotted
        without conversion. Otherwise the conversion is performed with
        ``pint``, which must then be installed.
    stride : int, optional
        Only data points every ``stride`` entries are plotted. Default is 1.
    axes : matplotlib.axes.Axes, optional
        Optionally, an existing Axes object can be passed, otherwise this
        function will create a new figure.
    plot_kwargs : Dict, optional
        Other keyword arguments passed to ``axes.plot`` for every column.
    data_time_unit : str, optional
        The unit of the 'time' column in ``data``. Default is 'fs'.
        NOTE: PLUMED's own default time unit is ps unless ``UNITS`` is set in
        the PLUMED input; make sure this matches how ``data`` was produced.

    Returns
    -------
    axes : matplotlib.axes.Axes
        The Axes object used for plotting.

    """
    # Instantiate mutable defaults.
    if plot_kwargs is None:
        plot_kwargs = {}

    # Create a new Figure if no Axes is passed.
    if axes is None:
        _, axes = plt.subplots()

    # If no column names are given, we plot all of them (except time).
    # Otherwise normalize a single name or any iterable of names to a list,
    # since the list is both iterated and measured below.
    if col_names is None:
        col_names = [k for k in data.keys() if k != 'time']
    elif isinstance(col_names, str):
        col_names = [col_names]
    else:
        col_names = list(col_names)

    # Convert the time dimension only if a different unit is requested.
    time = data['time']
    if time_unit is None:
        time_unit = data_time_unit
    elif time_unit != data_time_unit:
        # Lazy import: pint is only required when a conversion is needed.
        from pint import UnitRegistry
        ureg = UnitRegistry()
        time = ureg.Quantity(time, data_time_unit).to(time_unit).magnitude

    # Plot all the trajectories.
    for col_name in col_names:
        axes.plot(time[::stride], data[col_name][::stride],
                  label=col_name, **plot_kwargs)

    # Fix labels.
    axes.set_xlabel(f'simulation time [{time_unit}]')

    # If there are multiple trajectories, use a legend rather than a label.
    # With no columns at all there is nothing to label or put in a legend.
    if len(col_names) == 1:
        axes.set_ylabel(col_names[0])
    elif len(col_names) > 1:
        axes.legend()

    # There's no point in making the x axis start from negative numbers.
    # Empty data has no last time point, so leave the limits untouched.
    if len(time) > 0:
        axes.set_xlim((0, time[-1]))

    return axes