''' The `graph` module contains tools for constructing and executing graphs
of pliers Transformers. '''
from itertools import chain
from collections import OrderedDict
import json
from pliers.extractors.base import merge_results
from pliers.stimuli import __all__ as stim_list
from pliers.transformers import get_transformer
from pliers.utils import (listify, flatten, isgenerator, attempt_to_import,
verify_dependencies)
# pygraphviz is an optional dependency that is only needed by Graph.draw();
# draw() calls verify_dependencies(['pgv']) before touching `pgv`.
pgv = attempt_to_import('pygraphviz', 'pgv')

# Ordered list of type names used by Graph.draw() to pick a fill color by
# position. Build a module-local copy instead of inserting into the imported
# list: that list is pliers.stimuli's own `__all__`, and mutating it in place
# would leak 'ExtractorResult' into that package's public-name list (and add it
# again every time this module is reloaded).
stim_list = ['ExtractorResult'] + list(stim_list)
class Node:
    ''' A graph node/vertex. Represents a single transformer, optionally with
    references to children.

    Args:
        transformer (str, Transformer): the Transformer instance at this node,
            or a string naming the Transformer, in which case it is
            instantiated via get_transformer(transformer, **parameters).
        name (str): Optional name of the node. When given, it is also
            assigned to the transformer's `name` attribute.
        parameters (kwargs): parameters for initializing the Transformer.
            Only used for initialization when `transformer` is a string, but
            always stored on the node (and emitted by to_json()).
    '''

    def __init__(self, transformer, name=None, **parameters):
        self.name = name
        self.children = []
        if isinstance(transformer, str):
            transformer = get_transformer(transformer, **parameters)
        self.transformer = transformer
        self.parameters = parameters
        if name is not None:
            self.transformer.name = name
        # The node is identified by the identity of its transformer object;
        # Graph uses this value as the key of its node registry.
        self.id = id(transformer)

    def add_child(self, node):
        ''' Append a child to the list of children. '''
        self.children.append(node)

    def is_leaf(self):
        ''' Return True if this node has no children. '''
        return not self.children

    def to_json(self):
        ''' Return a JSON-serializable dict describing this node.

        The dict always contains 'transformer' (the transformer's class
        name). 'name', 'children' (recursively serialized) and 'parameters'
        are included only when they are non-empty.
        '''
        spec = {'transformer': self.transformer.__class__.__name__}
        if self.name:
            spec['name'] = self.name
        if self.children:
            spec['children'] = [child.to_json() for child in self.children]
        if self.parameters:
            spec['parameters'] = self.parameters
        return spec
class Graph:
    ''' Graph-like structure that represents an entire pliers workflow.

    Args:
        nodes (list, dict): Optional nodes to add to the Graph at construction.
            If a dict, must have a 'roots' key. If a list, each element must be
            in one of the forms accepted by add_nodes().
        spec (str): An optional path to a .json file containing the graph
            specification. Ignored if `nodes` is provided.
    '''

    def __init__(self, nodes=None, spec=None):
        # Registry of every node in the graph, keyed by Node.id.
        self.nodes = OrderedDict()
        # Nodes without a parent; run() starts execution from these.
        self.roots = []
        if nodes is not None:
            if isinstance(nodes, dict):
                nodes = nodes['roots']
            self.add_nodes(nodes)
        elif spec is not None:
            with open(spec) as spec_file:
                self.add_nodes(json.load(spec_file)['roots'])

    @staticmethod
    def _parse_node_args(node):
        ''' Normalize a node specification into add_node() keyword arguments.

        Args:
            node: A dict (returned unchanged), a list/tuple of the form
                (transformer[, children[, name]]), a Node instance, or
                anything else, which is treated as the transformer itself.

        Returns:
            A dict of keyword arguments suitable for add_node().
        '''
        if isinstance(node, dict):
            return node

        kwargs = {}
        if isinstance(node, (list, tuple)):
            kwargs['transformer'] = node[0]
            if len(node) > 1:
                kwargs['children'] = node[1]
            if len(node) > 2:
                kwargs['name'] = node[2]
        elif isinstance(node, Node):
            kwargs['transformer'] = node.transformer
            kwargs['children'] = node.children
            kwargs['name'] = node.name
        else:
            kwargs['transformer'] = node
        return kwargs

    def add_nodes(self, nodes, parent=None, mode='horizontal'):
        ''' Adds one or more nodes to the current graph.

        Args:
            nodes (list): A list of nodes to add. Each element must be one of
                the following:

                * A dict containing keyword args to pass onto add_node().
                * An iterable containing 1 - 3 elements. The first element is
                  mandatory, and specifies the Transformer at that node. The
                  second element (optional) is an iterable of child nodes
                  (specified in the same format). The third element
                  (optional) is a string giving the (unique) name of the
                  node.
                * A Node instance.
                * A Transformer instance.
            parent (Node): Optional parent node (i.e., the node containing the
                pliers Transformer from which the to-be-created nodes receive
                their inputs).
            mode (str): Indicates the direction with which to add the new nodes

                * horizontal: the nodes should each be added as a child of the
                  'parent' argument (or a Graph root by default).
                * vertical: the nodes should each be added in sequence with
                  the first node being the child of the 'parent' argument
                  (a Graph root by default) and each subsequent node being
                  the child of the previous node in the list.

        Raises:
            ValueError: if `mode` is neither 'horizontal' nor 'vertical'
                (detected when the first node is processed).
        '''
        for n in nodes:
            node_args = self._parse_node_args(n)
            if mode == 'horizontal':
                self.add_node(parent=parent, **node_args)
            elif mode == 'vertical':
                # Each newly created node becomes the parent of the next one.
                parent = self.add_node(parent=parent, return_node=True,
                                       **node_args)
            else:
                raise ValueError("Invalid mode for adding nodes to a graph:"
                                 "%s" % mode)

    def add_chain(self, nodes, parent=None):
        ''' An alias for add_nodes with the mode preset to 'vertical'. '''
        self.add_nodes(nodes, parent, 'vertical')

    def add_children(self, nodes, parent=None):
        ''' An alias for add_nodes with the mode preset to 'horizontal'. '''
        self.add_nodes(nodes, parent, 'horizontal')

    def add_node(self, transformer, name=None, children=None, parent=None,
                 parameters=None, return_node=False):
        ''' Adds a node to the current graph.

        Args:
            transformer (str, Transformer): The pliers Transformer to use at
                the to-be-added node. Either a case-insensitive string giving
                the name of a Transformer class, or an initialized Transformer
                instance.
            name (str): Optional name to give this Node.
            children (list): Optional list of child nodes (i.e., nodes to pass
                the to-be-added node's Transformer output to).
            parent (Node): Optional node from which the to-be-added Node
                receives its input.
            parameters (dict): Optional keyword arguments to pass onto the
                Transformer initialized at this Node if a string is passed to
                the 'transformer' argument. Ignored if an already-initialized
                Transformer is passed.
            return_node (bool): If True, returns the initialized Node instance.

        Returns:
            The initialized Node instance if return_node is True,
            None otherwise.
        '''
        # `parameters` defaults to None rather than a shared mutable dict.
        node = Node(transformer, name, **(parameters or {}))
        self.nodes[node.id] = node

        if parent is None:
            self.roots.append(node)
        else:
            # Attach to the node registered under the parent's id.
            parent = self.nodes[parent.id]
            parent.add_child(node)

        if children is not None:
            self.add_nodes(children, parent=node)

        return node if return_node else None

    def run(self, stim, merge=True, **merge_kwargs):
        ''' Executes the graph by calling all Transformers in sequence.

        Args:
            stim (str, Stim, list): One or more valid inputs to any
                Transformer's 'transform' call.
            merge (bool): If True, all results are merged into a single pandas
                DataFrame before being returned. If False, a list of
                ExtractorResult objects is returned (one per Extractor/Stim
                combination).
            merge_kwargs: Optional keyword arguments to pass onto the
                merge_results() call.
        '''
        # A generator can only be consumed once, so with several roots the
        # first one would leave nothing for the others. Cache it up front,
        # mirroring what run_node() does for nodes with several children.
        if len(self.roots) > 1 and isgenerator(stim):
            stim = list(stim)

        per_root = [self.run_node(root, stim) for root in self.roots]
        results = list(flatten(list(chain.from_iterable(per_root))))
        self._results = results  # For use in plotting
        return merge_results(results, **merge_kwargs) if merge else results

    transform = run

    def _resolve_node(self, node):
        ''' Return the Node that a run_node() `node` argument refers to.

        Non-string arguments are returned unchanged. A string is matched
        against the `name` of every registered node (self.nodes is keyed by
        Node.id, not by name, so a direct dict lookup cannot be used); the
        first match in insertion order wins.

        Raises:
            KeyError: if no registered node has the given name.
        '''
        if not isinstance(node, str):
            return node
        for candidate in self.nodes.values():
            if candidate.name == node:
                return candidate
        raise KeyError(node)

    def run_node(self, node, stim):
        ''' Executes the Transformer at a specific node.

        Args:
            node (str, Node): If a string, the name of the Node in the current
                Graph. Otherwise the Node instance to execute.
            stim (str, stim, list): Any valid input to the Transformer stored
                at the target node.

        Returns:
            A list with the results of the leaf nodes reachable from `node`.

        Raises:
            KeyError: if `node` is a string that matches no node name.
        '''
        node = self._resolve_node(node)

        result = node.transformer.transform(stim)
        if node.is_leaf():
            return listify(result)

        stim = result
        # If result is a generator, the first child will destroy the
        # iterable, so cache via list conversion
        if len(node.children) > 1 and isgenerator(stim):
            stim = list(stim)
        child_results = [self.run_node(c, stim) for c in node.children]
        return list(chain.from_iterable(child_results))

    def draw(self, filename, color=True):
        ''' Render a plot of the graph via pygraphviz.

        The plot is built from the history of the results of the most recent
        run() call, so the graph must have been executed first.

        Args:
            filename (str): Path to save the generated image to.
            color (bool): If True, will color graph nodes based on their type,
                otherwise will draw a black-and-white graph.

        Raises:
            RuntimeError: if the graph has not been run yet.
        '''
        verify_dependencies(['pgv'])
        if not hasattr(self, '_results'):
            raise RuntimeError("Graph cannot be drawn before it is executed. "
                               "Try calling run() first.")

        g = pgv.AGraph(directed=True)
        g.node_attr['colorscheme'] = 'set312'

        for elem in self._results:
            if not hasattr(elem, 'history'):
                continue
            log = elem.history

            # Walk the history chain from the result back to its origin,
            # emitting a source -> transformer -> result triple per entry.
            # NOTE(review): log[2], log[5], log[6] are presumably the source
            # type name, result type name and transformer class name of the
            # history entry -- confirm against the history log definition.
            while log:
                # Configure nodes
                source_from = log.parent[6] if log.parent else ''
                s_node = hash((source_from, log[2]))
                # Color index is the position in stim_list, wrapped into the
                # 1..12 range.
                s_color = stim_list.index(log[2])
                s_color = s_color % 12 + 1

                t_node = hash((log[6], log[7]))
                t_style = 'filled,' if color else ''
                t_style += 'dotted' if log.implicit else ''
                if log[6].endswith('Extractor'):
                    t_color = '#0082c8'
                elif log[6].endswith('Filter'):
                    t_color = '#e6194b'
                else:
                    t_color = '#3cb44b'

                r_node = hash((log[6], log[5]))
                r_color = stim_list.index(log[5])
                r_color = r_color % 12 + 1

                # Add nodes
                if color:
                    g.add_node(s_node, label=log[2], shape='ellipse',
                               style='filled', fillcolor=s_color)
                    g.add_node(t_node, label=log[6], shape='box',
                               style=t_style, fillcolor=t_color)
                    g.add_node(r_node, label=log[5], shape='ellipse',
                               style='filled', fillcolor=r_color)
                else:
                    g.add_node(s_node, label=log[2], shape='ellipse')
                    g.add_node(t_node, label=log[6], shape='box',
                               style=t_style)
                    g.add_node(r_node, label=log[5], shape='ellipse')

                # Add edges
                g.add_edge(s_node, t_node, style=t_style)
                g.add_edge(t_node, r_node, style=t_style)
                log = log.parent

        g.draw(filename, prog='dot')

    def to_json(self):
        ''' Returns the JSON representation of this graph. '''
        return {'roots': [root.to_json() for root in self.roots]}

    def save(self, filename):
        ''' Writes the JSON representation of this graph to the provided
        filename, such that the graph can be easily reconstructed using
        Graph(spec=filename).

        Args:
            filename (str): Path at which to write out the json file.
        '''
        with open(filename, 'w') as outfile:
            json.dump(self.to_json(), outfile)