axtreme.plotting.gp_fit
Plotting module for visualizing how well the GP fits the data.
Functions
| plot_1d_model | Plots a model with 1D input and any number of outputs. |
| plot_gp_fits_2d_surface | Plot the GP fit for the given metrics over the 2D search space. |
| plot_gp_fits_2d_surface_from_experiment | Plot the GP fit for the given trial index and metrics over the 2D search space of the experiment. |
| plot_surface_over_2d_search_space | Creates a figure with the functions in funcs plotted over the search_space. |
| scatter_plot_training | Make a scatter plot of a metric for the training data of the model. |
- axtreme.plotting.gp_fit.plot_1d_model(model: SingleTaskGP, X: Tensor | None = None, ax: None | Axes = None) → Axes
Plots a model with 1D input and any number of outputs.
- Parameters:
model – The model to plot. Only SingleTaskGP is supported; the training data is extracted from the model.
X – Input points of shape (n, 1) at which to plot the model. Defaults to a linspace over [0, 1]. Only 1D input is currently supported.
ax – If provided, the plot is drawn on this axis.
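The `X` parameter expects inputs shaped (n, 1). The NumPy sketch below builds such an input, analogous to the default linspace over [0, 1], and checks the 1D-only restriction. The helper names and the default of 101 points are hypothetical (the documentation does not state how many points the default uses), and the real function takes a torch `Tensor` rather than a NumPy array.

```python
import numpy as np


def make_1d_grid(n: int = 101) -> np.ndarray:
    """Hypothetical helper: n evenly spaced points on [0, 1], shaped (n, 1).

    The default n=101 is an illustrative assumption, not the library's value.
    """
    # np.linspace returns shape (n,); reshape adds the single input column.
    return np.linspace(0.0, 1.0, n).reshape(-1, 1)


def is_valid_1d_input(x: np.ndarray) -> bool:
    """Hypothetical check: True if x has the (n, 1) shape required for 1D input."""
    return x.ndim == 2 and x.shape[1] == 1


grid = make_1d_grid(5)
print(grid.shape)  # (5, 1)
print(is_valid_1d_input(grid), is_valid_1d_input(np.zeros((5, 2))))  # True False
```

In torch, an equivalent input would be `torch.linspace(0, 1, n).unsqueeze(-1)`.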
- axtreme.plotting.gp_fit.plot_gp_fits_2d_surface(model_bridge: TorchModelBridge, search_space: SearchSpace, metrics: dict[str, Callable[[ndarray[tuple[int, int], dtype[float64]]], ndarray[tuple[int], dtype[float64]]]] | None = None, num_points: int = 101, *, show_bounds: bool = True, show_point_idxs: bool = False) → Figure
Plot the GP fit for the given metrics over the 2D search space.
- Parameters:
model_bridge – The model bridge used to make predictions.
search_space – The search space over which the functions are to be evaluated and plotted.
metrics – A dictionary of metrics to plot. The keys are the names of the metrics in the model bridge's model, and the values are callables that return the metric value for the given inputs.
num_points – The number of points in each dimension to evaluate the functions at.
show_bounds – Whether to show the upper and lower bounds (based on the standard deviation) of the GP.
show_point_idxs – Whether to show the indices of the points (i.e., the order in which each point was added to the GP).
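Each value in `metrics` must map a float64 array of shape (n, 2) to an array of shape (n,). The sketch below shows a dictionary with that shape contract. The metric names `"loc"` and `"scale"` and their formulas are hypothetical placeholders; in practice the keys must match the metric names used by the model bridge's model.

```python
from collections.abc import Callable

import numpy as np
import numpy.typing as npt

# Matches the documented callable type: (n, 2) float64 array -> (n,) float64 array.
MetricFn = Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]


def true_loc(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Hypothetical 'loc' metric: a smooth function of both input columns."""
    return np.sin(x[:, 0]) + x[:, 1] ** 2


def true_scale(x: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
    """Hypothetical 'scale' metric."""
    return 1.0 + 0.1 * x[:, 0] * x[:, 1]


# Keys are placeholders; they must match metric names in the model bridge's model.
metrics: dict[str, MetricFn] = {"loc": true_loc, "scale": true_scale}

points = np.array([[0.0, 0.0], [1.0, 2.0], [0.5, 0.5]], dtype=np.float64)
for name, fn in metrics.items():
    print(name, fn(points).shape)  # (3,) for each metric
```

A dictionary like this could then be passed as `metrics=metrics` to `plot_gp_fits_2d_surface` or `plot_gp_fits_2d_surface_from_experiment`.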
- axtreme.plotting.gp_fit.plot_gp_fits_2d_surface_from_experiment(experiment: Experiment, trial_index: int, metrics: dict[str, Callable[[ndarray[tuple[int, int], dtype[float64]]], ndarray[tuple[int], dtype[float64]]]] | None = None, show_bounds: bool = True, show_point_idxs: bool = True) → Figure
Plot the GP fit for the given trial index and metrics over the 2D search space of the experiment.
- Parameters:
experiment – The experiment used to make predictions.
trial_index – The index of the trial to plot.
metrics – A dictionary of metrics to plot. The keys are the names of the metrics in the model bridge's model, and the values are callables that return the metric value for the given inputs.
show_bounds – Whether to show the upper and lower bounds (based on the standard deviation) of the GP.
show_point_idxs – Whether to show the indices of the points (i.e., the order in which each point was added to the GP).
- axtreme.plotting.gp_fit.plot_surface_over_2d_search_space(search_space: SearchSpace, funcs: list[Callable[[ndarray[tuple[int, int], dtype[float64]]], ndarray[tuple[int], dtype[float64]]]], colors: list[str] | None = None, num_points: int = 101) → Figure
Creates a figure with the functions in funcs plotted over the search_space.
Note
Currently only search spaces with 2 parameters are supported.
- Parameters:
search_space – The search space over which the functions are to be evaluated and plotted.
funcs – A list of callables that take a numpy array of shape (num_values, num_parameters=2) and return a numpy array of shape (num_values,).
colors – A list of colors to use for each function. If None, will use default Plotly colors.
num_points – The number of points in each dimension to evaluate the functions at.
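The sketch below shows one way a function in `funcs` could be evaluated on a `num_points` by `num_points` grid over two parameter ranges, using the documented (num_values, 2) to (num_values,) contract. It is an illustration, not the module's internal implementation: the helper `evaluate_on_2d_grid` is hypothetical, and the bounds are passed as plain (lower, upper) tuples as an assumption, whereas an Ax `SearchSpace` stores them on its parameters.

```python
from collections.abc import Callable

import numpy as np
import numpy.typing as npt

Array = npt.NDArray[np.float64]


def evaluate_on_2d_grid(
    func: Callable[[Array], Array],
    bounds: tuple[tuple[float, float], tuple[float, float]],
    num_points: int = 101,
) -> tuple[Array, Array, Array]:
    """Hypothetical helper: evaluate func on a num_points x num_points grid.

    Returns (xx, yy, zz), each of shape (num_points, num_points).
    """
    (x_lo, x_hi), (y_lo, y_hi) = bounds
    xs = np.linspace(x_lo, x_hi, num_points)
    ys = np.linspace(y_lo, y_hi, num_points)
    # Default indexing="xy": xx[i, j] == xs[j] and yy[i, j] == ys[i].
    xx, yy = np.meshgrid(xs, ys)
    # Flatten into the (num_values, 2) layout each callable in `funcs` accepts.
    flat = np.column_stack([xx.ravel(), yy.ravel()])
    values = func(flat)
    # Fold the (num_values,) result back onto the grid for surface plotting.
    return xx, yy, values.reshape(xx.shape)


xx, yy, zz = evaluate_on_2d_grid(
    lambda x: x[:, 0] + 10.0 * x[:, 1], bounds=((0.0, 1.0), (0.0, 2.0)), num_points=3
)
print(zz.shape)  # (3, 3)
```

Under this grid assumption, the default `num_points=101` means each function is called once on 101 * 101 = 10,201 points.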
- axtreme.plotting.gp_fit.scatter_plot_training(model_bridge: TorchModelBridge, metric_name: str, axis: tuple[int, int] = (0, 1), figure: Figure | None = None, *, error_bars: bool = True, error_bar_confidence_interval: float = 0.95, show_indices: bool = True) → Figure
Make a scatter plot of a metric for the training data of the model.
- Parameters:
model_bridge – The model bridge used to make predictions.
metric_name – The name of the metric to plot. Must match the name of a metric in the model.
axis – The indices of the two input-space dimensions to use as the axes of the scatter plot.
figure – The figure to add the scatter plot to. If None, a new figure is created.
error_bars – Whether to add error bars to the plot.
error_bar_confidence_interval – The confidence level that the error bars represent (e.g. 0.95 for a 95% interval).
show_indices – Whether to show the indices of the points (i.e., the order in which each point was added to the GP).
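If the error bars come from a Gaussian distribution with standard deviation `std` (an assumption; this page does not say how `error_bar_confidence_interval` is converted), the half-width of a two-sided interval at confidence level c is `z * std`, where z is the standard normal quantile at 0.5 + c/2. The helper `error_bar_half_width` below is a hypothetical sketch using only the standard library.

```python
from statistics import NormalDist


def error_bar_half_width(std: float, confidence: float = 0.95) -> float:
    """Hypothetical helper: half-width of a two-sided Gaussian interval.

    Assumes the error bars are symmetric and Gaussian, which is not confirmed
    by the scatter_plot_training documentation.
    """
    if not 0.0 < confidence < 1.0:
        raise ValueError("confidence must be in (0, 1)")
    # Two-sided: (1 - confidence) / 2 of the probability mass lies in each tail.
    z = NormalDist().inv_cdf(0.5 + confidence / 2.0)
    return z * std


print(round(error_bar_half_width(1.0, 0.95), 3))  # 1.96
```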