import numpy as np
import pyart
import os
import gc
import glob
from distributed import Client, wait
from scipy.interpolate import griddata
from copy import deepcopy
from datetime import datetime
from .wind_retrieve import get_dd_wind_field
# Reduces the resolution of a Py-ART Grid
def _reduce_pyart_grid_res(Grid, skip_factor):
Grid2 = deepcopy(Grid)
field_dict = {}
for field_name in Grid2.fields.keys():
field_dict[field_name] = Grid2.fields[field_name].copy()
field_dict[field_name]["data"] = Grid2.fields[field_name]["data"][
:, ::skip_factor, ::skip_factor]
x = Grid2.x
x["data"] = x["data"][::skip_factor]
y = Grid2.y
y["data"] = y["data"][::skip_factor]
    # The vertical resolution is left unchanged
    z = Grid2.z
metadata = Grid2.metadata
origin_latitude = Grid2.origin_latitude
origin_longitude = Grid2.origin_longitude
origin_altitude = Grid2.origin_altitude
projection = Grid2.projection
radar_latitude = Grid2.radar_latitude
radar_longitude = Grid2.radar_longitude
radar_altitude = Grid2.radar_altitude
radar_time = Grid2.radar_time
radar_name = Grid2.radar_name
gtime = Grid2.time
new_grid = pyart.core.Grid(
gtime, field_dict, metadata, origin_latitude, origin_longitude,
origin_altitude, x, y, z, projection, radar_latitude, radar_longitude,
radar_altitude, radar_time, radar_name)
del Grid2
return new_grid
# Splits a Py-ART Grid
def _split_pyart_grid(Grid, split_factor, axis=1):
grid_splits = []
split_field = {}
Grid2 = deepcopy(Grid)
for field_name in Grid2.fields.keys():
if isinstance(Grid2.fields[field_name]["data"], np.ma.MaskedArray):
no_mask = Grid2.fields[field_name]["data"].filled(np.nan).copy()
else:
no_mask = Grid2.fields[field_name]["data"].copy()
split_field[field_name] = np.array_split(
no_mask, split_factor, axis=axis)
if isinstance(Grid2.fields[field_name]["data"], np.ma.MaskedArray):
split_field[field_name] = [np.ma.masked_where(
np.isnan(arr), arr) for arr in split_field[field_name]]
x = Grid2.x
y = Grid2.y
z = Grid2.z
x_split = np.array_split(x["data"], split_factor)
y_split = np.array_split(y["data"], split_factor)
z_split = np.array_split(z["data"], split_factor)
gtime = Grid2.time
metadata = Grid2.metadata
origin_latitude = Grid2.origin_latitude
origin_longitude = Grid2.origin_longitude
origin_altitude = Grid2.origin_altitude
projection = Grid2.projection
radar_latitude = Grid2.radar_latitude
radar_longitude = Grid2.radar_longitude
radar_altitude = Grid2.radar_altitude
radar_time = Grid2.radar_time
radar_name = Grid2.radar_name
for i in range(split_factor):
grid_dic = {}
for field_name in Grid2.fields.keys():
grid_dic[field_name] = Grid2.fields[field_name].copy()
grid_dic[field_name]["data"] = split_field[field_name][i]
x_dic = x.copy()
y_dic = y.copy()
z_dic = z.copy()
        if axis == 1:
            y_dic["data"] = y_split[i]
        elif axis == 2:
            x_dic["data"] = x_split[i]
        elif axis == 0:
            z_dic["data"] = z_split[i]
new_grid = pyart.core.Grid(
gtime, grid_dic, metadata, origin_latitude, origin_longitude,
origin_altitude, x_dic, y_dic, z_dic, projection, radar_latitude,
radar_longitude, radar_altitude, radar_time, radar_name)
grid_splits.append(new_grid)
return grid_splits
# Concatenates Py-ART Grids
def _concatenate_pyart_grids(grid_list, axis=1):
new_grid = deepcopy(grid_list[0])
for field_name in new_grid.fields.keys():
new_grid.fields[field_name]["data"] = np.ma.concatenate(
[x.fields[field_name]["data"] for x in grid_list], axis=axis)
    if axis == 2:
        new_grid.x["data"] = np.ma.concatenate(
            [x.x["data"] for x in grid_list])
        new_grid.nx = np.sum([x.nx for x in grid_list])
    elif axis == 1:
        new_grid.y["data"] = np.ma.concatenate(
            [x.y["data"] for x in grid_list])
        new_grid.ny = np.sum([x.ny for x in grid_list])
    elif axis == 0:
        new_grid.z["data"] = np.ma.concatenate(
            [x.z["data"] for x in grid_list])
        new_grid.nz = np.sum([x.nz for x in grid_list])
return new_grid
# Procedure:
# 1. Do a first pass of the retrieval on a reduced-resolution grid.
# 2. Use the reduced-resolution retrieval as the initial state for the
#    high-resolution retrieval in each sub-domain.
# 3. Check for continuity at the sub-domain boundaries.
def get_dd_wind_field_nested(grid_list, u_init, v_init, w_init, client,
reduction_factor=2, num_splits=2, **kwargs):
"""
    This function performs a wind retrieval using a nested domain.
    This is useful for grids larger than about 500 by 500 by 40 points,
    since grids of that size on a single machine will exceed memory
    limitations. This procedure requires a running dask distributed
    cluster.

    The retrieval is first performed at a resolution that is coarser
    than the analysis grid by reduction_factor. This provides the
    initial state for the nested retrievals.

    The domain is then split into num_splits**2 sub-domains, and each
    nested retrieval is mapped onto a distributed worker for parallel
    processing. If NumPy and SciPy are already set up to use parallel
    numerical analysis libraries, it is recommended that a single
    machine, rather than a single core, be dedicated to each nest for
    best performance.

Parameters
==========
    grid_list: list
        A list of Py-ART grids for each radar to use in the retrieval.
    u_init: 3D NumPy array
        The initial guess of the zonal wind field. This must have the same
        shape as the analysis grid.
    v_init: 3D NumPy array
        The initial guess of the meridional wind field. This must have the
        same shape as the analysis grid.
    w_init: 3D NumPy array
        The initial guess of the vertical wind field. This must have the
        same shape as the analysis grid.
    client: dask distributed Client
        The distributed Client that is linked to a distributed cluster. The
        cluster must be running before get_dd_wind_field_nested is called.
        The retrieval on each nest will be mapped onto a worker. Since
        the optimization loop already takes advantage of parallelism, it is
        best to allow at least 16 cores per worker.
    reduction_factor: int
        The factor by which to reduce the resolution of the analysis grid
        when doing the initial retrieval on the entire grid.
    num_splits: int
        The number of splits to make along each horizontal axis when doing
        the nesting.
    **kwargs: dict
        This function takes the same keyword arguments as
        get_dd_wind_field, as these arguments are passed into each call of
        get_dd_wind_field. See get_dd_wind_field for more information on
        the available keyword arguments.
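
    Returns
    =======
    new_grid_list: list
        A list of Py-ART Grids, one per input radar, assembled from the
        sub-domain retrievals with the retrieved u, v, and w fields added.

    Examples
    ========
    A minimal usage sketch. The grid file names, the local dask cluster,
    and the zero-wind initialization are illustrative assumptions rather
    than part of the library::

        import numpy as np
        import pyart
        from distributed import Client

        client = Client()  # or connect to an existing cluster
        grids = [pyart.io.read_grid('radar_0_grid.nc'),
                 pyart.io.read_grid('radar_1_grid.nc')]
        # The initial wind guesses must match the analysis grid shape
        shape = (len(grids[0].z['data']), len(grids[0].y['data']),
                 len(grids[0].x['data']))
        u_init = np.zeros(shape)
        v_init = np.zeros(shape)
        w_init = np.zeros(shape)
        out_grids = get_dd_wind_field_nested(
            grids, u_init, v_init, w_init, client)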
"""
    # First, do the retrieval on the whole grid at a fraction of the
    # resolution
grid_lo_res_list = [_reduce_pyart_grid_res(G, reduction_factor)
for G in grid_list]
first_pass = get_dd_wind_field(
grid_lo_res_list, u_init[::, ::reduction_factor, ::reduction_factor],
v_init[::, ::reduction_factor, ::reduction_factor],
w_init[::, ::reduction_factor, ::reduction_factor], **kwargs)
    # Take the first-pass fields and regrid them to the analysis grid
reduced_x = first_pass[0].point_x["data"].flatten()
reduced_y = first_pass[0].point_y["data"].flatten()
reduced_z = first_pass[0].point_z["data"].flatten()
x = grid_list[0].point_x["data"].flatten()
y = grid_list[0].point_y["data"].flatten()
z = grid_list[0].point_z["data"].flatten()
u_init_new = griddata((reduced_z, reduced_y, reduced_x),
first_pass[0].fields["u"]["data"].flatten(),
(z, y, x), method='nearest')
v_init_new = griddata((reduced_z, reduced_y, reduced_x),
first_pass[0].fields["v"]["data"].flatten(),
(z, y, x), method='nearest')
w_init_new = griddata((reduced_z, reduced_y, reduced_x),
first_pass[0].fields["w"]["data"].flatten(),
(z, y, x), method='nearest')
u_init_new = np.reshape(u_init_new, u_init.shape)
v_init_new = np.reshape(v_init_new, v_init.shape)
w_init_new = np.reshape(w_init_new, w_init.shape)
    # Next, split the analysis domain into num_splits**2 pieces per radar
    # and save the pieces as temporary files
tempfile_name_base = datetime.now().strftime('%y%m%d.%H%M%S')
tiny_grids = []
k = 0
for G in grid_list:
cur_list = []
split_grids_x = _split_pyart_grid(G, num_splits, axis=2)
i = 0
for sgrid in split_grids_x:
g_list = _split_pyart_grid(sgrid, num_splits)
grid_fns = []
j = 0
for g in g_list:
                fn = (tempfile_name_base + str(k) + '.' +
                      str(i) + '.' + str(j) + '.nc')
                pyart.io.write_grid(fn, g)
j = j + 1
grid_fns.append(fn)
cur_list.append(grid_fns)
i = i + 1
del split_grids_x, g_list
k = k + 1
tiny_grids.append(cur_list)
    # The tiny grids are saved to disk so that each worker can load them on
    # demand when running on the cluster; split the regridded first-pass
    # winds to match the sub-domains, then free memory that is no longer
    # needed.
u_init_split_x = np.array_split(u_init_new, num_splits, axis=2)
u_init_split = [np.array_split(ux, num_splits, axis=1)
for ux in u_init_split_x]
w_init_split_x = np.array_split(w_init_new, num_splits, axis=2)
w_init_split = [np.array_split(wx, num_splits, axis=1)
for wx in w_init_split_x]
v_init_split_x = np.array_split(v_init_new, num_splits, axis=2)
v_init_split = [np.array_split(vx, num_splits, axis=1)
for vx in v_init_split_x]
# Clear out unneeded variables (do not need lo-res grids in memory anymore)
del u_init_split_x, w_init_split_x, v_init_split_x
del first_pass, reduced_x, reduced_y, reduced_z, x, y, z, grid_lo_res_list
gc.collect()
    # Helper that loads one sub-domain's grids from disk and runs the
    # retrieval on it; this runs remotely on a dask worker.
def do_tiny_retrieval(i, j):
tgrids = [pyart.io.read_grid(tiny_grids[k][i][j])
for k in range(len(grid_list))]
new_grids = get_dd_wind_field(
tgrids, u_init_split[i][j], v_init_split[i][j],
w_init_split[i][j], **kwargs)
del tgrids
gc.collect()
return new_grids
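    # Submit one retrieval task per (x-split, y-split) sub-domain to the
    # cluster.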
futures_array = []
for i in range(num_splits):
for j in range(num_splits):
futures_array.append(client.submit(do_tiny_retrieval, i, j))
print("Waiting for nested grid to be retrieved...")
wait(futures_array)
tiny_retrieval2 = client.gather(futures_array)
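    # Reassemble the sub-domains: first concatenate the y-direction splits
    # within each x-split, then concatenate the x-splits for each radar.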
tiny_retrieval = []
for i in range(num_splits):
new_grid_list = []
for j in range(len(grid_list)):
new_grid_list.append(_concatenate_pyart_grids(
[tiny_retrieval2[k+i*num_splits][j]
for k in range(0, num_splits)], axis=1))
tiny_retrieval.append(new_grid_list)
new_grid_list = []
for i in range(len(grid_list)):
new_grid_list.append(_concatenate_pyart_grids(
[tiny_retrieval[k][i] for k in range(num_splits)], axis=2))
tempfile_list = glob.glob(tempfile_name_base + "*")
for fn in tempfile_list:
os.remove(fn)
return new_grid_list