# Source code for oct2py.io

# -*- coding: utf-8 -*-
# Copyright (c) oct2py developers.
# Distributed under the terms of the MIT License.

from __future__ import absolute_import, print_function, division

import inspect
import dis
import threading

import numpy as np
from scipy.io import loadmat, savemat
from scipy.io.matlab.mio5 import MatlabObject, MatlabFunction
from scipy.sparse import spmatrix

from .compat import PY2
from .dynamic import OctaveVariablePtr, OctaveUserClass, OctaveFunctionPtr
from .utils import Oct2PyError


_WRITE_LOCK = threading.Lock()


def read_file(path, session=None):
    """Load the variables stored in the MAT-file at ``path``.

    Every loaded value is passed through ``_extract`` (together with
    ``session``) so it comes back as a Python-friendly object.

    Raises ``Oct2PyError`` when ``loadmat`` fails with a
    ``UnicodeDecodeError``.
    """
    try:
        contents = loadmat(path, struct_as_record=True)
    except UnicodeDecodeError as exc:
        raise Oct2PyError(str(exc))
    return {name: _extract(raw, session) for (name, raw) in contents.items()}


def write_file(obj, path, oned_as='row', convert_to_float=True):
    """Encode ``obj`` and save it as an Octave-readable MAT-file at ``path``.

    ``_encode(obj, convert_to_float)`` is handed to ``savemat`` as its
    variable mapping, so ``obj`` is expected to be dict-like.
    ``oned_as`` is forwarded to ``savemat``.

    Raises a plain ``Exception`` if ``savemat`` fails with ``KeyError``.
    """
    payload = _encode(obj, convert_to_float)
    # scipy.io.savemat is not thread-safe, so writes are serialized
    # through the module-level lock.
    # See https://github.com/scipy/scipy/issues/7260
    try:
        with _WRITE_LOCK:
            savemat(path, payload, appendmat=False, oned_as=oned_as,
                    long_field_names=True)
    except KeyError:  # pragma: no cover
        raise Exception('could not save mat file')


class Struct(dict):
    """
    Octave style struct, enhanced.

    Notes
    =====
    Supports dictionary and attribute style access.  Can be pickled,
    and supports code completion in a REPL.

    A missing member is created on the fly (as an empty ``Struct``) only
    when the calling code's next bytecode instruction is a ``STORE_ATTR``
    or a constant load, as in ``a.c.d = 1`` or ``a.c["d"] = 1``.  This is
    a bytecode heuristic.  An ordinary read of a missing member normally
    raises ``AttributeError`` (attribute access) or ``KeyError`` (item
    access).  Keys starting with an underscore and non-string keys are
    never created automatically.

    Examples
    ========
    >>> from pprint import pprint
    >>> from oct2py import Struct
    >>> a = Struct()
    >>> a.b = 'spam'  # a["b"] == 'spam'
    >>> a.c["d"] = 'eggs'  # a.c.d == 'eggs'
    >>> pprint(a)
    {'b': 'spam', 'c': {'d': 'eggs'}}
    """

    def __getattr__(self, attr):
        # Access the dictionary keys for unknown attributes.
        try:
            return self[attr]
        except KeyError:
            msg = "'Struct' object has no attribute %s" % attr
            raise AttributeError(msg)

    def __getitem__(self, attr):
        # Get a dict value; create a Struct if requesting a Struct member.
        # Do not create a key if the attribute starts with an underscore,
        # or if the key is not a string: it cannot name a member, and
        # ``startswith`` would fail on it.  A plain KeyError results.
        if (attr in self or not isinstance(attr, (str, type(u'')))
                or attr.startswith('_')):
            return dict.__getitem__(self, attr)
        frame = inspect.currentframe()
        try:
            # Interpreters without frame support return None; never
            # auto-create a member in that case.
            if frame is not None:
                caller = frame.f_back
                # Step into the function that called our caller (the user
                # code behind ``__getattr__``), then into our direct caller
                # (user code doing ``a["c"]...``).
                if caller.f_back and self._is_allowed(caller.f_back):
                    dict.__setitem__(self, attr, Struct())
                elif self._is_allowed(caller):
                    dict.__setitem__(self, attr, Struct())
        finally:
            # Break the frame -> locals -> frame reference cycle promptly.
            del frame
        return dict.__getitem__(self, attr)

    def _is_allowed(self, frame):
        # Return True when the instruction the calling frame will execute
        # right after its current one is an "allowed" op: a STORE_ATTR
        # (``a.b.c = x``) or a constant load (``a.b["c"] = x``).
        get_instructions = getattr(dis, 'get_instructions', None)
        lasti = frame.f_lasti
        if get_instructions is not None:
            # Python 3.4+: walk the decoded instructions so the encoding
            # width (wordcode since 3.6, inline caches since 3.11) does not
            # matter.  Taking the first instruction with a *larger* offset
            # also copes with f_lasti pointing into trailing cache entries.
            allowed = ('STORE_ATTR', 'LOAD_CONST',
                       # Python 3.14 loads small int constants with this.
                       'LOAD_SMALL_INT')
            for instr in get_instructions(frame.f_code):
                if instr.offset <= lasti or instr.opname == 'EXTENDED_ARG':
                    continue
                return instr.opname in allowed
            return False
        # Legacy bytecode (Python 2): assumes the current instruction is a
        # 3-byte opcode-with-argument such as LOAD_ATTR.
        allowed = [dis.opmap['STORE_ATTR'], dis.opmap['LOAD_CONST'],
                   dis.opmap.get('STOP_CODE', 0)]
        bytecode = frame.f_code.co_code
        index = lasti + 3
        if index >= len(bytecode):
            return False
        instruction = bytecode[index]
        instruction = ord(instruction) if PY2 else instruction
        return instruction in allowed

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

    @property
    def __dict__(self):
        # Allow for code completion in a REPL.
        return self.copy()
class StructArray(np.recarray):
    """A Python representation of an Octave structure array.

    Notes
    =====
    Accessing a record returns a Cell containing the values.
    This class is not meant to be directly created by the user.  It is
    created automatically for structure array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the struct array
    >>> octave.eval('x = struct("y", {1, 2}, "z", {3, 4});')
    >>> x = octave.pull('x')
    >>> x.y  # attribute access -> oct2py Cell
    Cell([[1.0, 2.0]])
    >>> x['z']  # item access -> oct2py Cell
    Cell([[3.0, 4.0]])
    >>> x[0, 0]  # index access -> numpy record
    (1.0, 3.0)
    >>> x[0, 1].z
    4.0
    """

    def __new__(cls, value, session=None):
        """Create a struct array from a value and optional Octave session.

        Without a session the (possibly squeezed) input is simply viewed
        as a StructArray.  With a session, every field of every element is
        converted with ``_extract``.
        """
        value = np.asarray(value)
        # Squeeze the last axis if it has length 1.  Zero-dimensional input
        # has no axis to squeeze (``shape[-1]`` would raise IndexError).
        if value.ndim and value.shape[-1] == 1:
            value = value.squeeze(axis=-1)
        value = np.atleast_1d(value)
        if not session:
            return value.view(cls)
        # Extract the values.
        obj = np.empty(value.size, dtype=value.dtype).view(cls)
        for (i, item) in enumerate(value.ravel()):
            for name in value.dtype.names:
                obj[i][name] = _extract(item[name], session)
        return obj.reshape(value.shape)

    @property
    def fieldnames(self):
        """The field names of the struct array."""
        return self.dtype.names

    def __getattribute__(self, attr):
        """Return object arrays as cells and all other values unchanged.
        """
        # Runs for every attribute lookup, so non-object results (shape,
        # dtype, numeric fields, ...) must pass through untouched.
        value = np.recarray.__getattribute__(self, attr)
        if isinstance(value, np.ndarray) and value.dtype.kind == 'O':
            return Cell(value)
        return value

    def __getitem__(self, item):
        """Return object arrays as cells and all other values unchanged.
        """
        value = np.recarray.__getitem__(self, item)
        if isinstance(value, np.ndarray) and value.dtype.kind == 'O':
            return Cell(value)
        return value

    def __repr__(self):
        # Report 1-D arrays as Nx1, matching Octave's column layout.
        shape = self.shape
        if len(shape) == 1:
            shape = (shape[0], 1)
        msg = 'x'.join(str(i) for i in shape)
        msg += ' StructArray containing the fields:'
        for key in self.fieldnames:
            msg += '\n %s' % key
        return msg
class Cell(np.ndarray):
    """A Python representation of an Octave cell array.

    Notes
    =====
    This class is not meant to be directly created by the user.  It is
    created automatically for cell array values received from Octave.
    The last axis is squeezed if it is of size 1 to simplify element access.

    Examples
    ========
    >>> from oct2py import octave
    >>> # generate the struct array
    >>> octave.eval("x = cell(2,2); x(:) = 1.0;")
    >>> x = octave.pull('x')
    >>> x
    Cell([[1.0, 1.0],
           [1.0, 1.0]])
    >>> x[0]
    Cell([1.0, 1.0])
    >>> x[0].tolist()
    [1.0, 1.0]
    """

    def __new__(cls, value, session=None):
        """Create a cell array from a value and optional Octave session.

        The value is coerced to an object array.  With a session, every
        element is converted with ``_extract``.
        """
        value = np.asarray(value, dtype=object)
        # Squeeze the last axis if it has length 1.  Zero-dimensional input
        # has no axis to squeeze (``shape[-1]`` would raise IndexError).
        if value.ndim and value.shape[-1] == 1:
            value = value.squeeze(axis=-1)
        value = np.atleast_1d(value)
        if not session:
            return value.view(cls)
        # Extract the values.
        obj = np.empty(value.size, dtype=object).view(cls)
        for (i, item) in enumerate(value.ravel()):
            obj[i] = _extract(item, session)
        return obj.reshape(value.shape)

    def __repr__(self):
        # Reuse numpy's formatting, renamed and without the dtype suffix.
        msg = self.view(np.ndarray).__repr__()
        msg = msg.replace('array', 'Cell', 1)
        return msg.replace(', dtype=object', '', 1)
def _extract(data, session=None): """Convert the Octave values to values suitable for Python. """ # Extract each item of a list. if isinstance(data, list): return [_extract(d, session) for d in data] # Ignore leaf objects. if not isinstance(data, np.ndarray): return data # Extract user defined classes. if isinstance(data, MatlabObject): cls = session._get_user_class(data.classname) return cls.from_value(data) # Extract struct data. if data.dtype.names: # Singular struct if data.size == 1: return _create_struct(data, session) # Struct array return StructArray(data, session) # Extract cells. if data.dtype.kind == 'O': return Cell(data, session) # Compress singleton values. if data.size == 1: return data.item() # Compress empty values. if data.size == 0: if data.dtype.kind in 'US': return '' return [] # Return standard array. return data def _create_struct(data, session): """Create a struct from session data. """ out = Struct() for name in data.dtype.names: item = data[name] # Extract values that are cells (they are doubly wrapped). if isinstance(item, np.ndarray) and item.dtype.kind == 'O': item = item.squeeze().tolist() out[name] = _extract(item, session) return out def _encode(data, convert_to_float): """Convert the Python values to values suitable to send to Octave. """ ctf = convert_to_float # Handle variable pointer. if isinstance(data, (OctaveVariablePtr)): return _encode(data.value, ctf) # Handle a user defined object. if isinstance(data, OctaveUserClass): return _encode(OctaveUserClass.to_value(data), ctf) # Handle a function pointer. if isinstance(data, (OctaveFunctionPtr, MatlabFunction)): raise Oct2PyError('Cannot write Octave functions') # Handle matlab objects. if isinstance(data, MatlabObject): view = data.view(np.ndarray) out = MatlabObject(data, data.classname) for name in out.dtype.names: out[name] = _encode(view[name], ctf) return out # Extract and encode values from dict-like objects. 
if isinstance(data, dict): out = dict() for (key, value) in data.items(): out[key] = _encode(value, ctf) return out # Send None as nan. if data is None: return np.NaN # Sets are treated like lists. if isinstance(data, set): return _encode(list(data), ctf) # Lists can be interpreted as numeric arrays or cell arrays. if isinstance(data, list): if _is_simple_numeric(data): return _encode(np.array(data), ctf) return _encode(tuple(data), ctf) # Tuples are handled as cells. if isinstance(data, tuple): obj = np.empty(len(data), dtype=object) for (i, item) in enumerate(data): obj[i] = _encode(item, ctf) return obj # Sparse data must be floating type. if isinstance(data, spmatrix): return data.astype(np.float64) # Return other data types unchanged. if not isinstance(data, np.ndarray): return data # Extract and encode data from object-like arrays. if data.dtype.kind in 'OV': out = np.empty(data.size, dtype=data.dtype) for (i, item) in enumerate(data.ravel()): if data.dtype.names: for name in data.dtype.names: out[i][name] = _encode(item[name], ctf) else: out[i] = _encode(item, ctf) return out.reshape(data.shape) # Complex 128 is the highest supported by savemat. if data.dtype.name == 'complex256': return data.astype(np.complex128) # Convert to float if applicable. if ctf and data.dtype.kind in 'ui': return data.astype(np.float64) # Return standard array. return data def _is_simple_numeric(data): """Test if a list contains simple numeric data.""" for item in data: if isinstance(item, set): item = list(item) if isinstance(item, list): if not _is_simple_numeric(item): return False elif not isinstance(item, (int, float, complex)): return False return True