MRG, MAINT: Simpler vector params #291

Merged: 6 commits, Jun 23, 2020

2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -49,7 +49,7 @@ jobs:
command: |
python -m pip install --user -q --upgrade pip numpy
python -m pip install --user -q --upgrade --progress-bar off scipy matplotlib vtk pyqt5 pyqt5-sip nibabel sphinx numpydoc pillow imageio imageio-ffmpeg sphinx-gallery
- python -m pip install --user -q --upgrade mayavi "https://api.github.com/repos/mne-tools/mne-python/zipball/master"
+ python -m pip install --user -q --upgrade mayavi "https://github.com/mne-tools/mne-python/archive/master.zip"
- save_cache:
key: pip-cache
paths:
2 changes: 1 addition & 1 deletion .travis.yml
@@ -50,7 +50,7 @@ before_install:
pip install https://github.com/enthought/mayavi/zipball/master;
fi;
- mkdir -p $SUBJECTS_DIR
- pip install "https://api.github.com/repos/mne-tools/mne-python/zipball/master";
- pip install "https://github.com/mne-tools/mne-python/archive/master.zip"
- python -c "import mne; mne.datasets.fetch_fsaverage(verbose=True)"

install:
46 changes: 19 additions & 27 deletions surfer/viz.py
@@ -423,7 +423,7 @@ def __init__(self, subject_id, hemi, surf, title=None,
title = subject_id
self.subject_id = subject_id

- if not isinstance(views, list):
+ if not isinstance(views, (list, tuple)):
views = [views]
n_row = len(views)
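
Note (not part of the diff): with this change a tuple of view names is accepted in addition to a list or a single view name. A rough usage sketch, with hypothetical subject/surface names:

    from surfer import Brain
    # tuples now work as well as lists for the views argument
    brain = Brain('fsaverage', 'lh', 'inflated', views=('lat', 'med'))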

@@ -1095,23 +1095,20 @@ def add_data(self, array, min=None, max=None, thresh=None,
smooth_mat = None

magnitude = None
- magnitude_max = None
if array.ndim == 3:
if array.shape[1] != 3:
raise ValueError('If array has 3 dimensions, array.shape[1] '
'must equal 3, got %s' % (array.shape[1],))
magnitude = np.linalg.norm(array, axis=1)
if scale_factor is None:
- distance = np.sum([array[:, dim, :].ptp(axis=0).max() ** 2
-                    for dim in range(3)])
+ distance = 4 * np.linalg.norm(array, axis=1).max()
if distance == 0:
scale_factor = 1
else:
scale_factor = (0.4 * distance /
(4 * array.shape[0] ** (0.33)))
if self._units == 'm':
scale_factor = scale_factor / 1000.
- magnitude_max = magnitude.max()
elif array.ndim not in (1, 2):
raise ValueError('array has must have 1, 2, or 3 dimensions, '
'got (%s)' % (array.ndim,))
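
Note (not part of the diff): the default scale_factor is now derived from the maximum vector norm rather than the per-axis peak-to-peak extent. A small numpy sketch of the new heuristic, for a hypothetical (n_vertices, 3, n_times) array:

    import numpy as np

    rng = np.random.RandomState(0)
    array = rng.randn(1000, 3, 5)              # hypothetical vector data
    magnitude = np.linalg.norm(array, axis=1)  # shape (n_vertices, n_times)
    distance = 4 * magnitude.max()
    if distance == 0:
        scale_factor = 1
    else:
        scale_factor = 0.4 * distance / (4 * array.shape[0] ** 0.33)
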
@@ -1188,7 +1185,7 @@ def time_label(x):
if brain['hemi'] == hemi:
s, ct, bar, gl = brain['brain'].add_data(
array, min, mid, max, thresh, lut, colormap, alpha,
- colorbar, layer_id, smooth_mat, magnitude, magnitude_max,
+ colorbar, layer_id, smooth_mat, magnitude,
scale_factor, vertices, vector_alpha, **kwargs)
surfs.append(s)
bars.append(bar)
@@ -2115,13 +2112,11 @@ def set_data_time_index(self, time_idx, interpolation='quadratic'):
if vectors is not None:
vectors = vectors[:, :, time_idx]

- vector_values = scalar_data.copy()
if data['smooth_mat'] is not None:
scalar_data = data['smooth_mat'] * scalar_data
for brain in self.brains:
if brain.hemi == hemi:
- brain.set_data(data['layer_id'], scalar_data,
-                vectors, vector_values)
+ brain.set_data(data['layer_id'], scalar_data, vectors)
del brain
data["time_idx"] = time_idx

@@ -3225,24 +3220,25 @@ def _remove_scalar_data(self, array_id):
self._mesh_clones.pop(array_id).remove()
self._mesh_dataset.point_data.remove_array(array_id)

- def _add_vector_data(self, vectors, vector_values, fmin, fmid, fmax,
-                      scale_factor_norm, vertices, vector_alpha, lut):
+ def _add_vector_data(self, vectors, fmin, fmid, fmax,
+                      scale_factor, vertices, vector_alpha, lut):
vertices = slice(None) if vertices is None else vertices
x, y, z = np.array(self._geo_mesh.data.points.data)[vertices].T
vector_alpha = min(vector_alpha, 0.9999999)
with warnings.catch_warnings(record=True): # HasTraits
quiver = mlab.quiver3d(
x, y, z, vectors[:, 0], vectors[:, 1], vectors[:, 2],
- scalars=vector_values, colormap='hot', vmin=fmin,
+ colormap='hot', vmin=fmin, scale_mode='vector',
vmax=fmax, figure=self._f, opacity=vector_alpha)

# Enable backface culling
quiver.actor.property.backface_culling = True
quiver.mlab_source.update()

- # Compute scaling for the glyphs
- quiver.glyph.glyph.scale_factor = (scale_factor_norm *
-                                    vector_values.max())
+ # Set scaling for the glyphs
+ quiver.glyph.glyph.scale_factor = scale_factor
+ quiver.glyph.glyph.clamping = False
+ quiver.glyph.glyph.range = (0., 1.)

# Scale colormap used for the glyphs
l_m = quiver.parent.vector_lut_manager
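
Note (not part of the diff): a minimal standalone sketch of the new scaling approach, assuming a working Mayavi installation; verts and vecs are made-up example data. With scale_mode='vector' the glyph length is taken from the vector norm itself, so no per-glyph scalar array is needed, and a fixed scale_factor with clamping disabled keeps the on-screen length proportional to the data magnitude:

    import numpy as np
    from mayavi import mlab

    rng = np.random.RandomState(0)
    verts = rng.rand(50, 3) * 100.  # hypothetical vertex positions
    vecs = rng.randn(50, 3)         # hypothetical vectors at those vertices

    quiver = mlab.quiver3d(
        verts[:, 0], verts[:, 1], verts[:, 2],
        vecs[:, 0], vecs[:, 1], vecs[:, 2],
        colormap='hot', scale_mode='vector')
    quiver.glyph.glyph.scale_factor = 10.  # fixed scaling, as above
    quiver.glyph.glyph.clamping = False
    quiver.glyph.glyph.range = (0., 1.)
    mlab.show()
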
@@ -3293,7 +3289,7 @@ def add_overlay(self, old, **kwargs):

@verbose
def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,
- colorbar, layer_id, smooth_mat, magnitude, magnitude_max,
+ colorbar, layer_id, smooth_mat, magnitude,
scale_factor, vertices, vector_alpha, **kwargs):
"""Add data to the brain"""
# Calculate initial data to plot
@@ -3308,24 +3304,20 @@ def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,
array_plot = magnitude[:, 0]
else:
raise ValueError("data has to be 1D, 2D, or 3D")
- vector_values = array_plot
if smooth_mat is not None:
array_plot = smooth_mat * array_plot

# Copy and byteswap to deal with Mayavi bug
array_plot = _prepare_data(array_plot)

array_id, pipe = self._add_scalar_data(array_plot)
- scale_factor_norm = None
if array.ndim == 3:
- scale_factor_norm = scale_factor / magnitude_max
vectors = array[:, :, 0].copy()
glyphs = self._add_vector_data(
- vectors, vector_values, fmin, fmid, fmax,
- scale_factor_norm, vertices, vector_alpha, lut)
+ vectors, fmin, fmid, fmax,
+ scale_factor, vertices, vector_alpha, lut)
else:
glyphs = None
- del scale_factor
mesh = pipe.parent
if thresh is not None:
if array_plot.min() >= thresh:
@@ -3364,7 +3356,7 @@ def add_data(self, array, fmin, fmid, fmax, thresh, lut, colormap, alpha,

self.data[layer_id] = dict(
array_id=array_id, mesh=mesh, glyphs=glyphs,
- scale_factor_norm=scale_factor_norm)
+ scale_factor=scale_factor)
return surf, orig_ctable, bar, glyphs

def add_annotation(self, annot, ids, cmap, **kwargs):
@@ -3475,7 +3467,7 @@ def remove_data(self, layer_id):
self._remove_scalar_data(data['array_id'])
self._remove_vector_data(data['glyphs'])

- def set_data(self, layer_id, values, vectors=None, vector_values=None):
+ def set_data(self, layer_id, values, vectors=None):
"""Set displayed data values and vectors."""
data = self.data[layer_id]
self._mesh_dataset.point_data.get_array(
@@ -3492,12 +3484,12 @@ def set_data(self, layer_id, values, vectors=None, vector_values=None):

# Update glyphs
q.mlab_source.vectors = vectors
- q.mlab_source.scalars = vector_values
q.mlab_source.update()

# Update changed parameters, and glyph scaling
- q.glyph.glyph.scale_factor = (data['scale_factor_norm'] *
-                               values.max())
+ q.glyph.glyph.scale_factor = data['scale_factor']
+ q.glyph.glyph.range = (0., 1.)
+ q.glyph.glyph.clamping = False
l_m.load_lut_from_list(lut / 255.)
l_m.data_range = data_range
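
Note (not part of the diff): because the glyphs are scaled by the vector norm with the stored, fixed scale factor, updating to a new time point only needs new vectors; there is no separate scalar array to keep in sync. A rough sketch, continuing from the quiver3d example above with hypothetical data:

    new_vecs = rng.randn(50, 3)             # vectors for the next time point
    quiver.mlab_source.vectors = new_vecs   # glyph lengths follow the new norms
    quiver.mlab_source.update()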
