Skip to content
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
501e5d5
Add grid interpolation support to Function class with from_grid() method
Copilot Nov 14, 2025
3d1337a
Add multi-dimensional drag coefficient support to Flight class and in…
Copilot Nov 14, 2025
dc7ad73
Run ruff format on modified files
Copilot Nov 14, 2025
560ef80
MNt: refactoring get_drag_coefficient in flight.py
aZira371 Nov 15, 2025
23ac66f
MNT: refactoring in flight.py and lint corrections to function.py and…
aZira371 Nov 19, 2025
365c2da
MNT: refactoring flight.py to remove unused parameters
aZira371 Nov 24, 2025
dc01807
MNT: correction of docstring function.py
aZira371 Nov 24, 2025
0b88906
MNT: make format and lint corrections to function.py
aZira371 Nov 24, 2025
5fe4625
MNT: pylint adjustments for new methods in function.py
aZira371 Nov 24, 2025
d832bf2
MNt: make format after previous change to function.py
aZira371 Nov 24, 2025
3f76344
MNT: removed Re where unused in test_multidim_drag.py
aZira371 Nov 24, 2025
74fe825
TST: Add tests for shepard_fallback in test_function_grid.py (#879)
Copilot Nov 27, 2025
e4053ac
TST: test_multidim_drag.py
aZira371 Nov 30, 2025
81f7bfa
MNT: addition of is_multidimensional to function.py
aZira371 Nov 30, 2025
ecb90ed
MNT: Added validation in from_grid in function.py to raise a ValueErr…
aZira371 Nov 30, 2025
e3fcad1
ENH: Added alpha-sensitive flight fixtures to flight_fixtures.py
aZira371 Nov 30, 2025
84350e6
MNT: renamed linear_grid to regular_grid for easy to understand nomen…
aZira371 Nov 30, 2025
953f796
MNT: replaced the broad except Exception: with except (TypeError, Va…
aZira371 Nov 30, 2025
d4d3771
TST: added from_grid unit tests to cover constructor-level validation…
aZira371 Nov 30, 2025
d646e46
MNT: format and lint update to test_function_from_grid.py
aZira371 Nov 30, 2025
cab76fa
DOC: changelog.md update for multidim drag
aZira371 Nov 30, 2025
fe2052b
Merge branch 'develop' into copilot/enhance-drag-curve-functionality
aZira371 Nov 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ good-names=FlightPhases,
center_of_mass_without_motor_to_CDM,
motor_center_of_dry_mass_to_CDM,
generic_motor_cesaroni_M1520,
Re, # Reynolds number

# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
Expand Down
278 changes: 276 additions & 2 deletions rocketpy/mathutils/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
LinearNDInterpolator,
NearestNDInterpolator,
RBFInterpolator,
RegularGridInterpolator,
)

from rocketpy.plots.plot_helpers import show_or_save_plot
Expand All @@ -43,6 +44,7 @@
"spline": 3,
"shepard": 4,
"rbf": 5,
"linear_grid": 6,
}
EXTRAPOLATION_TYPES = {"zero": 0, "natural": 1, "constant": 2}

Expand Down Expand Up @@ -449,6 +451,41 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab

self._interpolation_func = rbf_interpolation

elif interpolation == 6: # linear_grid (RegularGridInterpolator)
# For grid interpolation, the actual interpolator is stored separately
# This function is a placeholder that should not be called directly
# since __get_value_opt_grid is used instead
if hasattr(self, "_grid_interpolator"):

def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument
return self._grid_interpolator(x)

self._interpolation_func = grid_interpolation
else:
# Fallback to shepard if grid interpolator not available
warnings.warn(
"Grid interpolator not found, falling back to shepard interpolation"
)

def shepard_fallback(x, x_min, x_max, x_data, y_data, _):
# pylint: disable=unused-argument
arg_qty, arg_dim = x.shape
result = np.empty(arg_qty)
x = x.reshape((arg_qty, 1, arg_dim))
sub_matrix = x_data - x
distances_squared = np.sum(sub_matrix**2, axis=2)
zero_distances = np.where(distances_squared == 0)
valid_indexes = np.ones(arg_qty, dtype=bool)
valid_indexes[zero_distances[0]] = False
weights = distances_squared[valid_indexes] ** (-1.5)
numerator_sum = np.sum(y_data * weights, axis=1)
denominator_sum = np.sum(weights, axis=1)
result[valid_indexes] = numerator_sum / denominator_sum
result[~valid_indexes] = y_data[zero_distances[1]]
return result

self._interpolation_func = shepard_fallback

else:
raise ValueError(f"Interpolation {interpolation} method not recognized.")

Expand Down Expand Up @@ -635,6 +672,66 @@ def __get_value_opt_nd(self, *args):

return result

def __get_value_opt_grid(self, *args): # pylint: disable=unused-private-member
"""Evaluate the Function using RegularGridInterpolator for structured grids.

This method is dynamically assigned in from_grid() class method.

Parameters
----------
args : tuple
Values where the Function is to be evaluated. Must match the number
of dimensions of the grid.

Returns
-------
result : scalar or ndarray
Value of the Function at the specified points.
"""
# Check if we have the grid interpolator
if not hasattr(self, "_grid_interpolator"):
raise RuntimeError(
"Grid interpolator not initialized. Use from_grid() to create "
"a Function with grid interpolation."
)

# Convert args to appropriate format for RegularGridInterpolator
# RegularGridInterpolator expects points as (N, ndim) array
if len(args) != self.__dom_dim__:
raise ValueError(
f"Expected {self.__dom_dim__} arguments but got {len(args)}"
)

# Handle single point evaluation
point = np.array(args).reshape(1, -1)

# Handle extrapolation based on the extrapolation setting
if self.__extrapolation__ == "constant":
# Clamp point to grid boundaries for constant extrapolation
for i, axis in enumerate(self._grid_axes):
point[0, i] = np.clip(point[0, i], axis[0], axis[-1])
result = self._grid_interpolator(point)
elif self.__extrapolation__ == "zero":
# Check if point is outside bounds
outside_bounds = False
for i, axis in enumerate(self._grid_axes):
if point[0, i] < axis[0] or point[0, i] > axis[-1]:
outside_bounds = True
break
if outside_bounds:
result = np.array([0.0])
else:
result = self._grid_interpolator(point)
else:
# Natural or other extrapolation - use interpolator directly
result = self._grid_interpolator(point)

# Return scalar for single evaluation
if result.size == 1:
return float(result[0])

return result

def __determine_1d_domain_bounds(self, lower, upper):
"""Determine domain bounds for 1-D function discretization.

Expand Down Expand Up @@ -3891,11 +3988,11 @@ def __validate_interpolation(self, interpolation):
elif self.__dom_dim__ > 1:
if interpolation is None:
interpolation = "shepard"
if interpolation.lower() not in ["shepard", "linear", "rbf"]:
if interpolation.lower() not in ["shepard", "linear", "rbf", "linear_grid"]:
warnings.warn(
(
"Interpolation method set to 'shepard'. The methods "
"'linear', 'shepard' and 'rbf' are supported for "
"'linear', 'shepard', 'rbf' and 'linear_grid' are supported for "
"multiple dimensions."
),
)
Expand Down Expand Up @@ -3950,6 +4047,183 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument
"extrapolation": self.__extrapolation__,
}

@classmethod
def from_grid(
cls,
grid_data,
axes,
inputs=None,
outputs=None,
interpolation="linear_grid",
extrapolation="constant",
**kwargs,
): # pylint: disable=too-many-statements #TODO: Refactor this method into smaller methods
"""Creates a Function from N-dimensional grid data.

This method is designed for structured grid data, such as CFD simulation
results where values are computed on a regular grid. It uses
scipy.interpolate.RegularGridInterpolator for efficient interpolation.

Parameters
----------
grid_data : ndarray
N-dimensional array containing the function values on the grid.
For example, for a 3D function Cd(M, Re, α), this would be a 3D array
where grid_data[i, j, k] = Cd(M[i], Re[j], α[k]).
axes : list of ndarray
List of 1D arrays defining the grid points along each axis.
Each array should be sorted in ascending order.
For example: [M_axis, Re_axis, alpha_axis].
inputs : list of str, optional
Names of the input variables. If None, generic names will be used.
For example: ['Mach', 'Reynolds', 'Alpha'].
outputs : str, optional
Name of the output variable. For example: 'Cd'.
interpolation : str, optional
Interpolation method. Default is 'linear_grid'.
Currently only 'linear_grid' is supported for grid data.
extrapolation : str, optional
Extrapolation behavior. Default is 'constant', which clamps to edge values.
'constant': Use nearest edge value for out-of-bounds points.
'zero': Return zero for out-of-bounds points.
**kwargs : dict, optional
Additional arguments passed to the Function constructor.

Returns
-------
Function
A Function object using RegularGridInterpolator for evaluation.

Notes
-----
- Grid data must be on a regular (structured) grid.
- For unstructured data, use the regular Function constructor with
scattered points.
- Extrapolation with 'constant' mode uses the nearest edge values,
which is appropriate for aerodynamic coefficients where extrapolation
beyond the data range should be avoided.

Examples
--------
>>> import numpy as np
>>> # Create 3D drag coefficient data
>>> mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0])
>>> reynolds = np.array([1e5, 5e5, 1e6])
>>> alpha = np.array([0.0, 2.0, 4.0, 6.0])
>>> # Create a simple drag coefficient function
>>> M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing='ij')
>>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A
>>> # Create Function object
>>> cd_func = Function.from_grid(
... cd_data,
... [mach, reynolds, alpha],
... inputs=['Mach', 'Reynolds', 'Alpha'],
... outputs='Cd'
... )
>>> # Evaluate at a point
>>> cd_func(1.2, 3e5, 3.0)
0.48000000000000004

"""
# Validate inputs
if not isinstance(grid_data, np.ndarray):
grid_data = np.array(grid_data)

if not isinstance(axes, (list, tuple)):
raise ValueError("axes must be a list or tuple of 1D arrays")

# Ensure all axes are numpy arrays
axes = [
np.array(axis) if not isinstance(axis, np.ndarray) else axis
for axis in axes
]

# Check dimensions match
if len(axes) != grid_data.ndim:
raise ValueError(
f"Number of axes ({len(axes)}) must match grid_data dimensions "
f"({grid_data.ndim})"
)

# Check each axis matches corresponding grid dimension
for i, axis in enumerate(axes):
if len(axis) != grid_data.shape[i]:
raise ValueError(
f"Axis {i} has {len(axis)} points but grid dimension {i} "
f"has {grid_data.shape[i]} points"
)

# Set default inputs if not provided
if inputs is None:
inputs = [f"x{i}" for i in range(len(axes))]
elif len(inputs) != len(axes):
raise ValueError(
f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})"
)

# Create a new Function instance
func = cls.__new__(cls)

# Store grid-specific data first
func._grid_axes = axes
func._grid_data = grid_data

# Create RegularGridInterpolator
# We handle extrapolation manually in __get_value_opt_grid,
# so we set bounds_error=False and let it extrapolate linearly
# (which we'll override when needed)
func._grid_interpolator = RegularGridInterpolator(
axes,
grid_data,
method="linear",
bounds_error=False,
fill_value=None, # Linear extrapolation (will be overridden by manual handling)
)

# Create placeholder domain and image for compatibility
# This flattens the grid for any code expecting these attributes
mesh = np.meshgrid(*axes, indexing="ij")
domain_points = np.column_stack([m.ravel() for m in mesh])
func._domain = domain_points
func._image = grid_data.ravel()

# Set source as flattened data array (for compatibility with serialization, etc.)
func.source = np.column_stack([domain_points, func._image])
Comment on lines +4222 to +4228
Copy link

Copilot AI Nov 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Performance consideration: Creating the full flattened domain and image arrays (lines 4185-4191) for compatibility purposes requires O(n^d) memory where n is the typical axis length and d is the number of dimensions. For large 3D grids (e.g., 100x100x100 points), this creates redundant data since the grid interpolator already has the structured data.

While this is necessary for backward compatibility with code expecting _domain and _image attributes, consider documenting this memory overhead and potentially adding a parameter to skip this flattening if compatibility with existing Function methods is not required.

Copilot uses AI. Check for mistakes.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@aZira371 what do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm looking it this suggestion, will get back soon with some points on it.


# Initialize basic attributes
func.__inputs__ = inputs
func.__outputs__ = outputs if outputs is not None else "f"
func.__interpolation__ = interpolation
func.__extrapolation__ = extrapolation
func.title = kwargs.get("title", None)
func.__img_dim__ = 1
func.__cropped_domain__ = (None, None)
func._source_type = SourceType.ARRAY
func.__dom_dim__ = len(axes)

# Set basic array attributes for compatibility
func.x_array = axes[0]
func.x_initial, func.x_final = axes[0][0], axes[0][-1]
func.y_array = func._image[: len(axes[0])] # Placeholder
func.y_initial, func.y_final = func._image[0], func._image[-1]
if len(axes) > 2:
func.z_array = axes[2]
func.z_initial, func.z_final = axes[2][0], axes[2][-1]

# Set get_value_opt to use grid interpolation
func.get_value_opt = func.__get_value_opt_grid

# Set interpolation and extrapolation functions
func.__set_interpolation_func()
func.__set_extrapolation_func()

# Set inputs and outputs properly
func.set_inputs(inputs)
func.set_outputs(outputs)
func.set_title(func.title)

return func

@classmethod
def from_dict(cls, func_dict):
"""Creates a Function instance from a dictionary.
Expand Down
36 changes: 22 additions & 14 deletions rocketpy/rocket/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,20 +341,28 @@ def __init__( # pylint: disable=too-many-statements
)

# Define aerodynamic drag coefficients
self.power_off_drag = Function(
power_off_drag,
"Mach Number",
"Drag Coefficient with Power Off",
"linear",
"constant",
)
self.power_on_drag = Function(
power_on_drag,
"Mach Number",
"Drag Coefficient with Power On",
"linear",
"constant",
)
# If already a Function, use it directly (preserves multi-dimensional drag)
if isinstance(power_off_drag, Function):
self.power_off_drag = power_off_drag
else:
self.power_off_drag = Function(
power_off_drag,
"Mach Number",
"Drag Coefficient with Power Off",
"linear",
"constant",
)

if isinstance(power_on_drag, Function):
self.power_on_drag = power_on_drag
else:
self.power_on_drag = Function(
power_on_drag,
"Mach Number",
"Drag Coefficient with Power On",
"linear",
"constant",
)

# Create a, possibly, temporary empty motor
# self.motors = Components() # currently unused, only 1 motor is supported
Expand Down
Loading