Source code for metintos.io

# -*- coding: utf-8 -*-

import xarray as xr
import itertools
from metintos.optiflow import MultiFrameInterpolant
from metintos.utils import *
from math import ceil


[docs]def check_array_for_nans(arr): """ :param arr: :type arr: :returns: None """ assert not np.isnan(arr.sum())
[docs]class CoordinateGenerator(object): """Helper class for generating coordinates :var dict: Coordinates dictionary, mapping axis_name -> np.ndarray """
[docs] def __init__(self): """This declares the axes variable as a dictionary """ self.axes = {}
[docs] def add_axis_lims_resolution(self, ax_name, lower, upper, resolution): """Add an axis coordinate specifying limits and resolution (i.e. steps between coordinate points) :param ax_name: Name of the coordinate axis to be set :type ax_name: str :param lower: Lower limit of the coordinate axis :type lower: float :param upper: Upper limit of the coordinate axis :type upper: float :param resolution: Distance between axis points :type resolution: float :returns: None """ self.axes[ax_name] = np.arange(lower, upper, resolution)
[docs] def add_axis_lims_n_points(self, ax_name, lower, upper, n, dtype=None): """Add an axis coordinate specifying limits and number of points in coordinate axis :param ax_name: Name of the coordinate axis to be set :type ax_name: str :param lower: Lower limit of the coordinate axis :type lower: float :param upper: Upper limit of the coordinate axis :type upper: float :param n: Number of points in the coordinate axis :type n: int :param dtype: The type of the coordinate array. If dtype is not given, it will be inferred by numpy.linspace() :type dtype: dtype, optional :returns: None """ #self.axes[ax_name] = np.linspace(lower, upper, n, dtype=dtype) self.axes[ax_name] = lower + (upper - lower) * np.linspace(0, 1, n)
[docs]def get_var_name(dataset, variable): """ :param dataset: :type dataset: :param variable: :type variable: :returns: None """ try: return dataset[variable].shortname except AttributeError: return variable
[docs]class DatasetHandler(object): """Wrapper for a xarray.Dataset containing meteorological information at pressure levels :var ds: The dataset :type ds: xarray.Dataset """
[docs] def __init__(self, ds, output_dtype=np.float32, time_dim='step'): """ :param ds: The dataset :type ds: xarray.Dataset :param output_dtype: (optional) :type output_dtype: data-type :param time_dim: (optional) :type time_dim: str :returns: None """ self.time_dim = time_dim self.ds = ds self.output_dtype = np.float32 lat = self.ds.coords['latitude'] lon = self.ds.coords['longitude'] # check if we are under a dataset with ensemble self.is_ensemble = False try: self.is_ensemble = len(self.ds.number.values) except TypeError: pass if len(lat.shape) == 1: self.geo_grid = 'latlon' if lat[0] > lat[1]: self.ds = self.ds.reindex(latitude=lat[::-1]) # Monotonically ascending coordinates are required by some interpolants elif len(lat.shape) == 2: self.geo_grid = 'xy' if lat[0, 0] > lat[1, 0]: self.ds = self.ds.reindex(y=self.ds.coords['y'][::-1]) # Monotonically ascending coordinates are required by some interpolants else: raise ValueError(f"Cannot handle lat/lon coordinates with shape {lat.shape}") self.ds = self.ds.transpose(*self.default_dims)
[docs] @classmethod def load_from_steps(cls, path_list, **backend_kwargs): """Loads the meteo infromation from multiple grib files with consecutive forecast steps :param path_list: List of paths of the grib files :type path_list: list :param backend_kwargs: :type backend_kwargs: :returns: cls(ds) """ ds = xr.open_mfdataset(path_list, engine='cfgrib', concat_dim=['step'], combine='nested', **backend_kwargs) return cls(ds)
@property def default_dims(self): """Returns the default dimensions :returns: tuple with the dimensions. """ if self.geo_grid == 'latlon': dim_list = ('number', 'isobaricInhPa', self.time_dim, 'latitude', 'longitude') else: dim_list = ('number', 'isobaricInhPa', self.time_dim, 'y', 'x') return tuple(dim for dim in dim_list if dim in self.ds.dims.keys()) @ property def data_variables_list(self): """Return a list of variables that are not coordinates :returns: list of variables that are not coordinates """ all_vars = set(self.ds.variables.keys()) coords = set(self.ds.coords) data_vars = all_vars - coords return data_vars
[docs] def get_geographical_coordinate_slice_by_indexes(self, lat_idx_low, lat_idx_high, lon_idx_low, lon_idx_high): """Get a CoordinateGenerator (Not implemented) :param lat_idx_low: :type lat_idx_low: :param lat_idx_high: :type lat_idx_high: :param lon_idx_low: :type lon_idx_low: :param lon_idx_high: :type lon_idx_high: :returns: None """
[docs] def complete_new_coords(self, **coords): """Generates a coordinate set by completing the specificed new coordinates with the already existing ones :param coords: New coordinates, as a dict mapping variable names to monotonically increasing 1D arrays :type coords: dict :returns: New coordinates, completing missing dimensions with the already existing values :rtype: dict """ new_coords = {} for ax_name, ax in self.ds.coords.items(): new_coords[ax_name] = ax for ax_name, ax in coords.items(): new_coords[ax_name] = xr.IndexVariable(ax_name, ax, attrs=self.ds.coords[ax_name].attrs) try: del new_coords['valid_time'] except KeyError: pass return new_coords
[docs] def get_resampled_dataset(self, **coords): """Resamples the dataset to new coordinates :param coords: A dictionary of new coordinates, mapping coordinate names to numpy arrays. If a coordinate axis is missing, the old coordinate axis is used by default :type coords: dict :returns: A dataset reinterpolated to the new coordinates :rtype: xarray.Dataset """ data_vars = self.data_variables_list for dv in data_vars: # fake loop to pick any element from the data_vars set break dims = self.ds[dv].dims new_coords = self.complete_new_coords(**coords) new_shape = tuple(len(new_coords[d]) for d in dims) placeholder_dataarray = xr.DataArray(data=np.zeros(new_shape), dims=dims, coords=new_coords) return self.ds.interp_like(placeholder_dataarray)
[docs] def get_optical_flow_interpolated_dataset(self, new_coords, variables=None, flow_calculator=None, spatial_interp=None, verbose=False, zero_out_nans=True): """Reinterpolates the dataset to the new coordinates using optical flow interpolation :param new_coords: A dictionary of new coordinates, mapping coordinate names to numpy arrays. If a coordinate axis is missing, the old coordinate axis is used by default :type new_coords: dict :param variables: (optional) List of variables to be transported to the new dataset :type variables: list :param flow_calculator: (optional) Callable that computes the optical flow between two frames / arrays :type flow_calculator: (R² x R²) -> R² mapping :param spatial_interp: (optional) Callable that generates a 2D interpolant of a frame array with normalized coordinates :type spatial_interp: R² -> (R² -> R) :param verbose: (optional) If true, display the progress of the dataset reinterpolation procedure :type verbose: bool :param zero_out_nans: (optional) if True, replace any nans in the input with zeros :type zero_out_nans: bool :returns: The reinterpolated dataset :rtype: xarray.Dataset """ new_coords = self.complete_new_coords(**new_coords) if variables is None: variables = [get_var_name(self.ds, var) for var in self.ds.data_vars] new_vars = {var: self.get_optical_flow_interpolated_variable(new_coords, var, flow_calculator, spatial_interp, verbose=verbose, zero_out_nans=zero_out_nans) for var in variables} return xr.Dataset(data_vars=new_vars, coords=new_coords, attrs=self.ds.attrs)
[docs] @staticmethod def get_coords_shape(dims, coords): """Returns the shape of an array with the given coordinates and dimension ordering :param dims: dimension :type dims: :param coords: coordinates :type coords: :returns: The coordinates shape :rtype: tuple """ ll_dims = len(coords['latitude'].shape) if ll_dims == 1: return tuple(coords[d].size for d in dims) elif ll_dims == 2: geo_coords = ('y', 'x') return tuple(coords[d].size for d in dims if d not in geo_coords) + coords['latitude'].shape
[docs] def get_optical_flow_interpolated_variable(self, new_coords, variable, flow_calculator=None, spatial_interp=None, flow_padding=.1, verbose=False, zero_out_nans=True): """ :param new_coords: A dictionary of new coordinates, mapping coordinate names to numpy arrays. If a coordinate axis is missing, the old coordinate axis is used by default :type new_coords: dict :param variable: variable to be transported to the new dataset :type variable: str :param flow_calculator: (optional) Callable that computes the optical flow between two frames / arrays :type flow_calculator: (R² x R²) -> R² mapping :param spatial_interp: (optional) Callable that generates a 2D interpolant of a frame array with normalized coordinates :type spatial_interp: R² -> (R² -> R) :param flow_padding: (optional) :type flow_padding: float :param verbose: (optional) If true, display the progress of the dataset reinterpolation procedure :type verbose: bool :param zero_out_nans: (optional) if True, replace any nans in the input with zeros :type zero_out_nans: bool :returns: The reinterpolated dataset :rtype: xarray.Dataset """ reinterpolate = False reinterpolate_ens_level = False old_coords = self.ds.coords intermediate_coords = {} if self.time_dim in new_coords: intermediate_coords[self.time_dim] = new_coords[self.time_dim] limits = {'latitude': (-90.0, 90.0), 'longitude': (-180.0, 180.0)} geo_coords = limits.keys() for coord in geo_coords: try: assert np.allclose(new_coords[coord].values, old_coords[coord].values) except (ValueError, AssertionError): reinterpolate = self.geo_grid icg = CoordinateGenerator() if reinterpolate == "latlon": for coord in geo_coords: if coord in new_coords: axis_new = new_coords[coord].values axis_old = old_coords[coord].values resolution = axis_old[1] - axis_old[0] # assuming regular lat/lon grids span = axis_new[-1] - axis_new[0] padding = flow_padding * span padding_indexes = ceil(padding / resolution) low = max(axis_new[0] - padding_indexes * resolution, limits[coord][0], axis_old.min()) high = min(axis_new[-1] + (padding_indexes + 1) * resolution, limits[coord][1], axis_old.max()) icg.add_axis_lims_resolution(coord, low, high, resolution) intermediate_coords.update(icg.axes) reinterpolate_ens_level = True elif reinterpolate == "xy": raise NotImplementedError("Automatic clipping for optical flow calculation is not implemented for " "non-Plate-Carrée grids (i.e. 2D lat/lon grids). Regrid before or after" "optical flow resampling") else: icg.axes['latitude'] = old_coords['latitude'] icg.axes['longitude'] = old_coords['longitude'] for dim in self.default_dims: if dim not in (self.time_dim, ): if dim in new_coords.keys(): try: assert np.allclose(new_coords[dim].values, old_coords[dim].values) except (ValueError, AssertionError): reinterpolate_ens_level = True break intermediate_coords = self.complete_new_coords(**intermediate_coords) intermediate_shape = self.get_coords_shape(self.default_dims, intermediate_coords) intermediate_var_array = np.zeros(intermediate_shape, dtype=self.output_dtype) intermediate_data_array = xr.DataArray(data=intermediate_var_array, coords=intermediate_coords, dims=self.default_dims, name=variable, attrs=self.ds[variable].attrs) coords_to_iterate = {} if self.is_ensemble: coords_to_iterate_keys = ('number', 'isobaricInhPa') else: coords_to_iterate_keys = ('isobaricInhPa',) for dim_key in coords_to_iterate_keys: try: coords_to_iterate[dim_key] = list(new_coords[dim_key].values) except (KeyError, TypeError): coords_to_iterate[dim_key] = [None] iterate_combinations = itertools.product(*coords_to_iterate.values()) for coords_combi in iterate_combinations: if verbose: print(coords_combi) coords_selector_iter = {key: coords_combi[k] for k, key in enumerate(coords_to_iterate_keys) if coords_combi[k] is not None} coords_selector = coords_selector_iter.copy() if reinterpolate == 'latlon': coords_selector['latitude'] = icg.axes['latitude'] coords_selector['longitude'] = icg.axes['longitude'] def get_values(step, cs=coords_selector): _cs = cs.copy() _cs[self.time_dim] = step return self.ds[variable].sel(**_cs).values snapshots = [get_values(step) for step in self.ds.coords[self.time_dim]] if zero_out_nans: snapshots = list(map(np.nan_to_num, snapshots)) mfi = MultiFrameInterpolant(snapshots, flow_calculator, spatial_interp, old_coords[self.time_dim].values) for i, step in enumerate(new_coords[self.time_dim]): frame = mfi.eval_at_t(step.values) _csi = coords_selector_iter.copy() _csi[self.time_dim] = step intermediate_data_array.loc[dict(**_csi)] = frame if reinterpolate_ens_level: new_shape = self.get_coords_shape(self.default_dims, new_coords) new_var_array = np.zeros(new_shape, dtype=self.output_dtype) new_data_array = xr.DataArray(data=new_var_array, coords=new_coords, dims=self.default_dims, name=variable, attrs=self.ds[variable].attrs) return intermediate_data_array.interp_like(new_data_array) else: return intermediate_data_array