Source code for pleiades.sammy.results.manager

import json
from pathlib import Path

from pleiades.sammy.io.lpt_manager import LptManager
from pleiades.sammy.io.lst_manager import LstManager
from pleiades.sammy.results.models import FitResults, RunResults
from pleiades.utils.logger import loguru_logger

logger = loguru_logger.bind(name=__name__)


class ResultsManager:
    """
    A class to manage and extract results from SAMMY calculations.

    On construction, an optional LPT file and/or LST file path is handed to the
    corresponding manager (``LptManager`` / ``LstManager``). Each manager is given
    the shared ``run_results`` container.

    Attributes:
        run_results (RunResults): A container for multiple fit results.
        lpt_manager (LptManager | None): Manager for the LPT file, or None if no
            LPT file path was supplied.
        lst_manager (LstManager | None): Manager for the LST file, or None if no
            LST file path was supplied.
    """

    def __init__(
        self,
        lpt_file_path: Path = None,
        lst_file_path: Path = None,
    ):
        """
        Create the results container and attach managers for any supplied files.

        Args:
            lpt_file_path (Path | str | None): Path to a SAMMY LPT file. Non-Path
                values are converted with ``Path(...)``. If None, no LPT manager
                is created.
            lst_file_path (Path | str | None): Path to a SAMMY LST file. Non-Path
                values are converted with ``Path(...)``. If None, no LST manager
                is created.
        """
        self.run_results = RunResults()

        # Always define both manager attributes so callers can check them against
        # None rather than hitting AttributeError when a path was not supplied.
        self.lpt_manager = None
        self.lst_manager = None

        # Convert to Path if passed as string (or other path-like value).
        if lpt_file_path is not None and not isinstance(lpt_file_path, Path):
            lpt_file_path = Path(lpt_file_path)
        if lst_file_path is not None and not isinstance(lst_file_path, Path):
            lst_file_path = Path(lst_file_path)

        # If a file path is provided, hand it to its manager together with the
        # shared RunResults container, from which results are extracted.
        if lpt_file_path is not None:
            self.lpt_manager = LptManager(lpt_file_path, self.run_results)
        if lst_file_path is not None:
            self.lst_manager = LstManager(lst_file_path, self.run_results)

    def add_fit_result(self, fit_result: FitResults):
        """Add a FitResults object to the RunResults."""
        self.run_results.add_fit_result(fit_result)

    def get_single_fit_results(self, index: int) -> FitResults:
        """
        Retrieve a single fit result from the list.

        Args:
            index (int): Position in ``run_results.fit_results``. Indexing is
                delegated directly to that container.

        Returns:
            FitResults: The fit result stored at ``index``.

        Raises:
            ValueError: If no fit results are available.
            IndexError: If ``index`` is out of range for the (non-empty) container.
        """
        if self.run_results.fit_results:
            return self.run_results.fit_results[index]
        raise ValueError("No fit results available.")

    def print_fit_result(self, index: int):
        """
        Log a specific fit result as indented JSON.

        Logs a warning instead of raising when there are no fit results or when
        ``index`` is out of range.

        Args:
            index (int): Position of the fit result to log.
        """
        try:
            fit_result = self.get_single_fit_results(index)
        except (ValueError, IndexError):
            # ValueError: no results at all; IndexError: index out of range.
            logger.warning(f"No fit result found at index {index}.")
            return
        # default=str stringifies values json cannot serialize natively.
        logger.info(f"Fit Result {index}:\n{json.dumps(fit_result.model_dump(), indent=2, default=str)}")

    def print_number_of_fit_results(self):
        """Log the number of fit results."""
        num_fit_results = len(self.run_results.fit_results)
        logger.info(f"Number of fit results: {num_fit_results}")

    def print_run_results(self):
        """Log every fit result, or a warning if there are none."""
        if not self.run_results.fit_results:
            logger.warning("No fit results available.")
            return
        for fit_result in self.run_results.fit_results:
            logger.info(f"Fit Result: {fit_result}")

    def print_results_data(self):
        """Log the results data, or a warning if no data (or no data file) is set."""
        if self.run_results.data and self.run_results.data.data_file:
            logger.info(f"Results Data: {self.run_results.data}")
        else:
            logger.warning("No results data available.")

    def plot_transmission(
        self,
        override_data_type: bool = False,
        show_diff: bool = False,
        plot_uncertainty: bool = False,
        figsize=None,
        title=None,
        xscale="linear",
        yscale="linear",
        data_color="#433E3F",
        final_color="#ff6361",
        show=True,
    ):
        """
        Plot the transmission data from the results.

        Args:
            override_data_type (bool): Force plotting even if data type is not transmission.
            show_diff (bool): If True, plot the residuals.
            plot_uncertainty (bool): If True, plot error bars.
            figsize (tuple): Figure size (width, height) in inches.
            title (str): Plot title.
            xscale (str): X-axis scale ('linear' or 'log').
            yscale (str): Y-axis scale ('linear' or 'log').
            data_color (str): Color for experimental data points.
            final_color (str): Color for fitted theoretical curve.
            show (bool): If True, display the plot. If False, return figure object.

        Returns:
            Whatever ``data.plot_transmission`` returns (documented as the matplotlib
            figure when show=False), or None if there is no data or the data type
            is not transmission and no override was requested.
        """
        data = self.run_results.data
        if not data:
            logger.warning("No results data available for plotting.")
            return None
        # Check if data type is transmission
        if data.data_type != "TRANSMISSION" and not override_data_type:
            logger.warning("Data type is not transmission. Cannot plot.")
            return None
        return data.plot_transmission(
            show_diff=show_diff,
            plot_uncertainty=plot_uncertainty,
            figsize=figsize,
            title=title,
            xscale=xscale,
            yscale=yscale,
            data_color=data_color,
            final_color=final_color,
            show=show,
        )

    def plot_cross_section(
        self, override_data_type: bool = False, show_diff: bool = False, plot_uncertainty: bool = False
    ):
        """
        Plot the cross-section data from the results.

        Args:
            override_data_type (bool): Force plotting even if data type is not cross-section.
            show_diff (bool): If True, plot the residuals.
            plot_uncertainty (bool): If True, plot error bars.

        Returns:
            Whatever ``data.plot_cross_section`` returns (forwarded for consistency
            with ``plot_transmission``), or None if there is no data or the data
            type is not cross-section and no override was requested.
        """
        data = self.run_results.data
        if not data:
            logger.warning("No results data available for plotting.")
            return None
        if data.data_type != "CROSS_SECTION" and not override_data_type:
            logger.warning("Data type is not cross-section. Cannot plot.")
            return None
        return data.plot_cross_section(show_diff=show_diff, plot_uncertainty=plot_uncertainty)

    def get_data(self):
        """
        Get the data from the results.

        Returns:
            The ``run_results.data`` object if it is truthy, otherwise None (after
            logging a warning).
        """
        if self.run_results.data:
            return self.run_results.data
        logger.warning("No results data available.")
        return None