Source code for pliers.transformers.base

''' Core transformer logic. '''
from abc import ABCMeta, abstractmethod, abstractproperty
import importlib
import logging
from functools import wraps

from pliers import config
from pliers.stimuli.base import Stim, _log_transformation, load_stims
from pliers.stimuli.compound import CompoundStim
from pliers.utils import (progress_bar_wrapper, isiterable,
                          isgenerator, listify, batch_iterable,
                          attempt_to_import, set_iterable_type)
import pliers


# Optional parallel backend. Transformer._iterate only uses this name when it
# "is not None", and then calls multiprocessing.ProcessingPool(n_jobs).map(...).
# NOTE(review): presumably attempt_to_import yields None when pathos is not
# installed -- confirm against pliers.utils.
multiprocessing = attempt_to_import('pathos.multiprocessing',
                                    'multiprocessing', ['ProcessingPool'])

# Module-wide memoization store shared by every Transformer. Keys are
# hash((hash(transformer), hash(stim))); values are transformation results.
# Nothing in this module ever evicts entries.
_cache = {}


class Transformer(metaclass=ABCMeta):
    ''' Base class for all pliers Transformers.

    Subclasses must provide an `_input_type` attribute/property and a
    `_transform` method; everything else (input loading, validation,
    implicit conversion, iteration, caching and context propagation) is
    handled here.

    Args:
        name (str): Optional name of Transformer instance. If None (default),
            the class name is used.
    '''

    # Names of instance attributes that identify a configured transformer.
    # They are folded into __hash__, and therefore into the cache key.
    _log_attributes = ()
    _loggable = True
    VERSION = '0.1'
    # Stim types that *can* be passed as input, but aren't mandatory. This
    # allows for disjunctive specification; e.g., if _input_type is empty
    # and _optional_input_type is (AudioStim, TextStim), then _at least_ one
    # of the two must be passed. If both are specified in _input_type, then
    # the input would have to be a CompoundStim with both audio and text slots.
    _optional_input_type = ()

    def __init__(self, name=None, **kwargs):
        if name is None:
            name = self.__class__.__name__
        self.name = name
        # Cooperative init so that mixins further along the MRO get their
        # keyword arguments.
        super().__init__(**kwargs)

    def _memoize(transform):
        ''' Decorator that caches transform() results in the module-level
        _cache when the 'cache_transformers' option is enabled.

        Only single Stim/str inputs are cached. The key is built from the
        transformer hash and the stim hash only.
        '''
        # NOTE(review): extra args/kwargs (including `validation`) are not
        # part of the cache key, so calls differing only in those arguments
        # share one cache entry.
        @wraps(transform)
        def wrapper(self, stim, *args, **kwargs):
            use_cache = config.get_option('cache_transformers') \
                and isinstance(stim, (Stim, str))
            if use_cache:
                key = hash((hash(self), hash(stim)))
                if key in _cache:
                    return _cache[key]
            result = transform(self, stim, *args, **kwargs)
            if use_cache:
                # Generators can only be consumed once, so materialize them
                # before storing.
                if isgenerator(result):
                    result = list(result)
                _cache[key] = result
            return result
        return wrapper

    @_memoize
    def transform(self, stims, validation='strict', *args, **kwargs):
        ''' Executes the transformation on the passed stim(s).

        Args:
            stims (str, Stim, list): One or more stimuli to process. Must be
                one of:

                    - A string giving the path to a file that can be read in
                      as a Stim (e.g., a .txt file, .jpg image, etc.)
                    - A Stim instance of any type.
                    - An iterable of stims, where each element is either a
                      string or a Stim.

            validation (str): String specifying how validation errors should
                be handled. Must be one of:

                    - 'strict': Raise an exception on any validation error
                    - 'warn': Issue a warning for all validation errors
                    - 'loose': Silently ignore all validation errors

                Any other value is treated like 'strict'.

            args: Optional positional arguments to pass onto the internal
                _transform call.
            kwargs: Optional keyword arguments to pass onto the internal
                _transform call.

        Returns:
            The transformation result for a single stim, an iterable of
            results for iterable input, or None when validation fails in
            'warn' or 'loose' mode.

        Raises:
            ValueError: If a CompoundStim is passed that contains no stim of
                the (single) expected input type.
            TypeError: If the stim cannot be handled or converted and
                validation is not 'warn' or 'loose'.
        '''

        if isinstance(stims, str):
            stims = load_stims(stims)

        # If stims is a CompoundStim and the Transformer is expecting a single
        # input type, extract all matching stims
        if isinstance(stims, CompoundStim) and \
                not isinstance(self._input_type, tuple):
            stims = stims.get_stim(self._input_type, return_all=True)
            if not stims:
                raise ValueError("No stims of class %s found in the provided "
                                 "CompoundStim instance." % self._input_type)

        # If stims is an iterable, naively loop over elements, removing
        # invalid results if needed
        if isiterable(stims):
            iters = self._iterate(stims, validation=validation, *args,
                                  **kwargs)
            if config.get_option('drop_bad_extractor_results'):
                iters = (i for i in iters if i is not None)
            iters = progress_bar_wrapper(iters, desc='Stim')
            return set_iterable_type(iters)

        # Validate stim, and then either pass it directly to the Transformer
        # or, if a conversion occurred, recurse.
        try:
            validated_stim = self._validate(stims)
        except TypeError as err:
            if validation == 'warn':
                logging.warning(str(err))
                return None
            if validation == 'loose':
                return None
            # 'strict', or an unrecognized mode: surface the real error
            # rather than continuing with an unbound validated_stim.
            raise

        # If a conversion occurred during validation, we recurse. The
        # validation mode is forwarded explicitly; otherwise it would reset
        # to 'strict' and the first positional extra argument would be
        # consumed as the validation mode.
        if stims is not validated_stim:
            return self.transform(validated_stim, validation, *args, **kwargs)

        result = self._transform(validated_stim, *args, **kwargs)
        result = _log_transformation(validated_stim, result, self)
        if isgenerator(result):
            result = list(result)
        self._propagate_context(validated_stim, result)
        return result

    def _validate(self, stim):
        # Checks whether the current Transformer can handle the passed Stim.
        # If not, attempts a dynamic conversion before failing.
        # Returns the original stim if it already matches, otherwise the
        # converted stim; raises TypeError if no converter is available.
        if not self._stim_matches_input_types(stim):
            # Imported here rather than at module level.
            # NOTE(review): presumably to avoid a circular import -- confirm.
            from pliers.converters.base import get_converter
            in_type = self._input_type if self._input_type \
                else self._optional_input_type
            converter = get_converter(type(stim), in_type)
            if converter:
                _old_stim = stim
                stim = converter.transform(stim)
                stim = _log_transformation(_old_stim, stim, converter, True)
            else:
                msg = ("Transformers of type %s can only be applied to stimuli"
                       " of type(s) %s (not type %s), and no applicable "
                       "Converter was found.")
                msg = msg % (self.__class__.__name__, in_type,
                             stim.__class__.__name__)
                raise TypeError(msg)
        return stim

    def _stim_matches_input_types(self, stim):
        # Checks if passed Stim meets all _input_type and _optional_input_type
        # specifications.
        mandatory = tuple(listify(self._input_type))
        optional = tuple(listify(self._optional_input_type))

        if isinstance(stim, CompoundStim):
            return stim.has_types(mandatory) or \
                (not mandatory and stim.has_types(optional, False))

        # A non-compound stim can never satisfy several mandatory types.
        if len(mandatory) > 1:
            msg = ("Transformer of class %s requires multiple mandatory "
                   "inputs, so the passed input Stim must be a CompoundStim"
                   "--which it isn't." % self.__class__.__name__)
            logging.warning(msg)
            return False

        return isinstance(stim, mandatory) or \
            (not mandatory and isinstance(stim, optional))

    def _iterate(self, stims, *args, **kwargs):
        # Applies transform() to every element of stims, either through a
        # process pool (when the 'parallelize' option is on and the optional
        # multiprocessing backend is available) or through a lazy generator.
        if config.get_option('parallelize') and multiprocessing is not None:
            def _transform(s):
                return self.transform(s, *args, **kwargs)
            n_jobs = config.get_option('n_jobs')
            return multiprocessing.ProcessingPool(n_jobs) \
                .map(_transform, stims)

        # The serial path drops falsy results (the parallel path does not).
        return (t for t in (self.transform(s, *args, **kwargs)
                            for s in stims) if t)

    def _propagate_context(self, stim, result):
        # Copies onset, duration and order from the source stim onto each
        # result that doesn't define them yet, recursing into iterables.
        if isiterable(result):
            for r in result:
                self._propagate_context(stim, r)
        else:
            if result.onset is None:
                result.onset = stim.onset
            if result.duration is None:
                result.duration = stim.duration
            if result.order is None:
                result.order = stim.order

    @abstractmethod
    def _transform(self, stim):
        # Subclasses implement the actual transformation of a validated stim.
        pass

    @property
    @abstractmethod
    def _input_type(self):
        # Subclasses declare the Stim class (or tuple of classes) they accept.
        pass

    def __hash__(self):
        # Two transformers hash alike when they share a name and the values
        # of all attributes listed in _log_attributes.
        tr_attrs = [getattr(self, attr) for attr in self._log_attributes]
        return hash(self.name + str(dict(zip(self._log_attributes, tr_attrs))))
class BatchTransformerMixin(Transformer):
    ''' Mixin replacing the per-stim loop with batched processing.

    Intended for Transformers whose _transform method can take a list of
    stimuli in one call (e.g., API Extractors that accept list inputs), so
    that iteration happens in chunks instead of one stim at a time.

    Args:
        batch_size (int): Number of Stims to process in each batch.
        args, kwargs: Optional positional and keyword arguments to pass onto
            the base Transformer initializer.
    '''

    def __init__(self, batch_size=None, *args, **kwargs):
        # A falsy batch_size leaves any existing _batch_size attribute
        # untouched.
        if batch_size:
            self._batch_size = batch_size
        super().__init__(*args, **kwargs)

    def _iterate(self, stims, validation='strict', *args, **kwargs):
        # Processes stims batch by batch and returns a flat list of results
        # in input order, reusing cached results where caching is enabled.

        def cache_key(item):
            # Same key layout as the one used to index _cache elsewhere.
            return hash((hash(self), hash(item)))

        collected = []
        for chunk in progress_bar_wrapper(
                batch_iterable(stims, self._batch_size)):
            caching = config.get_option('cache_transformers')
            # Maps cache key -> position of the stim inside `pending`.
            positions = {}
            pending = []
            for item in chunk:
                k = cache_key(item)
                # With caching on, skip stims that are already cached or
                # that appeared earlier in this batch.
                if not caching or (k not in _cache and k not in positions):
                    positions[k] = len(pending)
                    pending.append(item)

            # _transform will likely fail if given an empty list
            fresh = self._transform(pending, *args, **kwargs) \
                if pending else []

            for item in chunk:
                k = cache_key(item)
                if k not in positions:
                    # Not transformed in this batch, so it must be cached.
                    collected.append(_cache[k])
                    continue
                outcome = _log_transformation(item, fresh[positions[k]], self)
                self._propagate_context(item, outcome)
                if caching:
                    if isgenerator(outcome):
                        outcome = list(outcome)
                    _cache[k] = outcome
                collected.append(outcome)
        return collected

    def _transform(self, stim, *args, **kwargs):
        # Hands the whole list to the next _transform in the MRO when every
        # stim already matches the input types; otherwise falls back to the
        # parent's element-wise iteration.
        as_list = listify(stim)
        if not all(self._stim_matches_input_types(s) for s in as_list):
            return list(super()._iterate(as_list, *args, **kwargs))

        output = super()._transform(as_list, *args, **kwargs)
        # A single (non-iterable) input gets a single result back.
        return output if isiterable(stim) else output[0]
def get_transformer(name, base=None, *args, **kwargs):
    ''' Looks up a Transformer class by name and returns an instance of it.

    The `__all__` list of each searched pliers subpackage is scanned in
    order, and the first class whose name matches exactly (ignoring case) is
    instantiated.

    Args:
        name (str): The name of the transformer to retrieve. Case-insensitive;
            e.g., 'stftextractor' or 'CornerDetectionExtractor'.
        base (str, list): Optional name of transformer modules to search.
            Valid values are 'converters', 'extractors', and 'filters'.
        args, kwargs: Optional positional or keyword arguments to pass onto
            the Transformer.

    Raises:
        KeyError: If no class with a matching name is found.
    '''

    name = name.lower()
    wanted = name.lower()

    # Default to searching all kinds of Transformers
    search_in = listify(
        ['extractors', 'converters', 'filters'] if base is None else base)

    for pkg_name in search_in:
        # Import the subpackage so it is available as an attribute of pliers.
        importlib.import_module('pliers.%s' % pkg_name)
        module = getattr(pliers, pkg_name)
        for candidate in module.__all__:
            if candidate.lower() != wanted:
                continue
            return getattr(module, candidate)(*args, **kwargs)

    raise KeyError("No transformer named '%s' found." % name)