# Source code for gunagala.utils

# Licensed under a 3-clause BSD style license - see LICENSE.rst

# This sub-module is destined for common non-package specific utility
# functions that will ultimately be merged into `astropy.utils`
import os
import functools
import numpy as np
import astropy.units as u
from astropy.table import Table
from astropy.utils.data import get_pkg_data_filename


def ensure_unit(arg, unit):
    """
    Ensures that the argument has the requested units, performing conversions
    as necessary.

    Parameters
    ----------
    arg : astropy.units.Quantity or compatible
        Argument to be coerced into the requested units. Can be an
        `astropy.units.Quantity` instance or any numeric type or sequence that
        is compatible with the `Quantity` constructor (e.g. a `numpy.array`,
        `list` of `float`, etc.).
    unit : astropy.units.Unit
        Requested units.

    Returns
    -------
    arg : astropy.units.Quantity
        `arg` as an `astropy.units.Quantity` with units of `unit`.

    Raises
    ------
    astropy.units.UnitConversionError
        If `arg` already has units that are incompatible with `unit`.
    """
    try:
        # Guard only the attribute lookup. An AttributeError (or any other
        # error) raised *inside* a conversion method must not be mistaken for
        # "arg has no units" and silently turned into a multiplication.
        convert = arg.to
    except AttributeError:
        # No `to()` method, so arg isn't a Quantity or compatible class:
        # attach the requested units rather than converting.
        return arg * unit
    # Errors from the conversion itself (e.g. UnitConversionError for
    # incompatible units) propagate unchanged to the caller.
    return convert(unit)
def get_table_data(data_table, column_names, column_units,
                   data_dir='data/performance_data', **kwargs):
    """
    Parses a data table to extract specified columns, converted to Quantity
    with specified units.

    Parameters
    ----------
    data_table: astropy.table.Table or str
        The data table for parsing, either as an astropy.table.Table object or
        the name of a file that can be read by `astropy.table.Table.read()`.
        The filename can be either the path to a user file or the name of one
        of gunagala's included files (looked up under `data_dir` in the
        gunagala package data).
    column_names: sequence
        Names of the columns to extract from the table.
    column_units: sequence
        Desired units for the extracted columns. If data_table specifies units
        for its columns then the extracted columns will be converted to these
        units. If not then the specified units will be assumed for the
        corresponding column. A Table passed in by the caller is not
        modified. Columns are paired with units by position; extra entries in
        the longer of the two sequences are ignored.
    data_dir: str, optional
        Package data sub-directory searched when `data_table` is not a path
        to an existing file. Default 'data/performance_data'.

    Additional keyword arguments will be passed to the call to
    astropy.table.Table.read() if reading a Table from a file. See the
    documentation for Table.read() for details of the available parameters.

    Returns
    -------
    data: list of astropy.units.Quantity
        List of Quantity objects corresponding to the named columns, with the
        specified units.

    Raises
    ------
    IOError
        If `data_table` is a name that is neither an existing path nor found
        in the gunagala package data.
    ValueError
        If the table has no column with one of the requested names.
    """
    if not isinstance(data_table, Table):
        # data_table isn't a Table, assume it's a filename.
        if not os.path.exists(data_table):
            # Not a (valid) path to a user file, look in package data directories
            try:
                data_table = get_pkg_data_filename(os.path.join(data_dir, data_table),
                                                   package='gunagala')
            except Exception:
                # Not in package data directories either. `except Exception`
                # rather than a bare `except:` so KeyboardInterrupt/SystemExit
                # are not converted into an IOError.
                raise IOError("Couldn't find data table {}!".format(data_table))
        data_table = Table.read(data_table, **kwargs)

    data = []
    for name, unit in zip(column_names, column_units):
        try:
            column = data_table[name]
        except KeyError:
            raise ValueError("Data table has no column named {}!".format(name))
        if not column.unit:
            # Column has no units: assume the requested ones. Work on a copy so
            # that a Table supplied by the caller is not mutated as a side
            # effect of reading data from it.
            column = column.copy()
            column.unit = unit
            data.append(column.quantity)
        else:
            data.append(column.quantity.to(unit))
    return data
def array_sequence_equal(array_sequence, reference=None):
    """
    Determine if all array objects in a sequence are equal.

    Parameters
    ----------
    array_sequence: sequence of numpy.array
        Sequence of numpy.array or compatible type (e.g.
        astropy.units.Quantity) objects to compare. The objects must support
        element-wise comparison.
    reference: numpy.array, optional
        If given all arrays in the sequence will be compared with reference,
        otherwise they will be compared with each other. Must be a numpy.array
        or compatible type (e.g. Quantity).

    Returns
    -------
    equal: bool
        True if all arrays in the sequence are equal (or equal to reference,
        if given), otherwise False.

    Raises
    ------
    ValueError
        If `array_sequence` is empty.
    """
    if len(array_sequence) == 0:
        raise ValueError('array_sequence must contain at least one array object!')

    if reference is None:
        reference = array_sequence[0]

    for array in array_sequence:
        try:
            elements_equal = array == reference
        except (AttributeError, ValueError):
            # Recent NumPy versions raise ValueError when the shapes cannot be
            # broadcast together (older ones returned a scalar False);
            # either way the arrays are not equal.
            return False
        # An explicit test rather than `assert`, which is stripped under
        # `python -O` and would make this function always return True.
        # np.all() (rather than the .all() method) also handles comparisons
        # that return a plain Python bool, e.g. astropy returning False for
        # incompatible units.
        if not np.all(elements_equal):
            return False
    return True
def bin_array(data, binning_factor, bin_func=np.sum):
    """
    Bin 2D array data by a given factor using a given binning function.

    Parameters
    ----------
    data: numpy.array
        2D array to be binned (or anything `numpy.asanyarray` accepts; array
        subclasses such as Quantity are preserved). Both dimensions must be
        exactly divisible by `binning_factor`.
    binning_factor: int
        Size of the binning regions, in pixels.
    bin_func: function, optional
        Function to be used to combine the pixel values within each binning
        region. The function must accept a numpy.array as the first argument,
        and accept an axis keyword argument to specify which array axis to
        perform the combination on. Default numpy.sum().

    Returns
    -------
    binned: numpy.array
        Binned array, with shape (rows // binning_factor,
        columns // binning_factor).

    Raises
    ------
    ValueError
        If `data` is not 2D, `binning_factor` is less than 1, or the shape of
        `data` is not divisible by `binning_factor`.
    """
    data = np.asanyarray(data)
    if data.ndim != 2:
        raise ValueError("data must be a 2D array, got {} dimension(s)!".format(data.ndim))
    if binning_factor < 1:
        raise ValueError("binning_factor must be a positive integer, got {}!".format(
            binning_factor))

    n_rows, n_cols = data.shape
    if n_rows % binning_factor or n_cols % binning_factor:
        # Otherwise reshape() fails with an opaque size-mismatch error.
        raise ValueError("data shape {} is not divisible by binning_factor {}!".format(
            data.shape, binning_factor))

    # Split each axis into (number of bins, pixels per bin) so that each binning
    # region can be combined by reducing over the two 'pixels per bin' axes.
    shape = (n_rows // binning_factor, binning_factor,
             n_cols // binning_factor, binning_factor)
    return bin_func(bin_func(data.reshape(shape), axis=3), axis=1)