123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- # axis3d.py, original mplot3d version by John Porter
- # Created: 23 Sep 2005
- # Parts rewritten by Reinier Heeres <reinier@heeres.eu>
import numpy as np

from matplotlib import (
    artist, cbook, colors as mcolors, lines as mlines, axis as maxis,
    patches as mpatches, rcParams)
from . import art3d, proj3d
@cbook.deprecated("3.1")
def get_flip_min_max(coord, index, mins, maxs):
    """
    Return the bound on the side opposite to *coord* along axis *index*.

    When ``coord[index]`` equals ``mins[index]`` the matching entry of
    *maxs* is returned; in every other case the entry of *mins* is.
    """
    on_min_side = coord[index] == mins[index]
    return maxs[index] if on_min_side else mins[index]
def move_from_center(coord, centers, deltas, axmask=(True, True, True)):
    """
    Push *coord* outward from *centers* by *deltas*.

    Only the components whose entry in *axmask* is true are moved.  The
    direction of the push on each component is the sign of
    ``coord - centers``, as given by `numpy.copysign` (a difference of +0.0
    counts as positive).
    """
    coord = np.asarray(coord)
    direction = np.copysign(1, coord - centers)
    return coord + axmask * direction * deltas
def tick_update_position(tick, tickxs, tickys, labelpos):
    """
    Update tick line and label position and style.

    Both labels are moved to *labelpos*.  Only ``tick1line`` is shown; it is
    drawn as a solid, marker-less line through *tickxs* / *tickys*.  The
    gridline's data is reset to ``(0, 0)``.
    """
    for label in (tick.label1, tick.label2):
        label.set_position(labelpos)
    tick.tick1line.set_visible(True)
    tick.tick2line.set_visible(False)
    line = tick.tick1line
    line.set_linestyle('-')
    line.set_marker('')
    line.set_data(tickxs, tickys)
    tick.gridline.set_data(0, 0)
class Axis(maxis.XAxis):
    """
    An Axis class for the 3D plots.

    One instance handles one of the x, y or z directions of a 3D axes.
    `draw_pane` draws the background pane; `draw` draws the axis line,
    axis label, offset text, grid lines and ticks.  Positions are computed
    in 3D data coordinates and projected to 2D with `proj3d` using
    ``renderer.M`` at draw time.
    """

    # These points from the unit cube make up the x, y and z-planes.
    # Each entry lists four corner indices into the result of
    # ``axes.tunit_cube``; entries 2*i and 2*i + 1 are the two opposite
    # planes used as candidate panes for data dimension i.
    _PLANES = (
        (0, 3, 7, 4), (1, 2, 6, 5),  # yz planes
        (0, 1, 5, 4), (3, 2, 6, 7),  # xz planes
        (0, 1, 2, 3), (4, 5, 6, 7),  # xy planes
    )

    # Some properties for the axes:
    #   'i'       -- index of this axis' data dimension (0, 1 or 2);
    #   'tickdir' -- dimension along which tick marks are drawn across the
    #                axis line in `draw`;
    #   'juggled' -- dimension order `draw` uses to pick the end points of
    #                the axis line and the anchor of the offset text;
    #   'color'   -- default RGBA color of the pane.
    _AXINFO = {
        'x': {'i': 0, 'tickdir': 1, 'juggled': (1, 0, 2),
              'color': (0.95, 0.95, 0.95, 0.5)},
        'y': {'i': 1, 'tickdir': 0, 'juggled': (0, 1, 2),
              'color': (0.90, 0.90, 0.90, 0.5)},
        'z': {'i': 2, 'tickdir': 0, 'juggled': (0, 2, 1),
              'color': (0.925, 0.925, 0.925, 0.5)},
    }

    def __init__(self, adir, v_intervalx, d_intervalx, axes, *args,
                 rotate_label=None, **kwargs):
        """
        Parameters
        ----------
        adir : {'x', 'y', 'z'}
            Which direction this axis represents; selects its `_AXINFO`
            entry.
        v_intervalx, d_intervalx : pair of floats
            Initial view and data intervals, assigned through the
            `v_interval` and `d_interval` properties after the base class
            has been initialized.
        axes
            The parent axes (presumably an ``Axes3D``).
        *args, **kwargs
            Forwarded to `matplotlib.axis.XAxis`.
        rotate_label : bool or None, default: None
            See `set_rotate_label`.
        """
        # adir identifies which axes this is
        self.adir = adir

        # This is a temporary member variable.
        # Do not depend on this existing in future releases!
        # A shallow copy suffices: the nested dicts added below are built
        # fresh for every instance and the shared values are tuples.
        self._axinfo = self._AXINFO[adir].copy()
        if rcParams['_internal.classic_mode']:
            self._axinfo.update(
                {'label': {'va': 'center',
                           'ha': 'center'},
                 'tick': {'inward_factor': 0.2,
                          'outward_factor': 0.1,
                          'linewidth': rcParams['lines.linewidth']},
                 'axisline': {'linewidth': 0.75,
                              'color': (0, 0, 0, 1)},
                 'grid': {'color': (0.9, 0.9, 0.9, 1),
                          'linewidth': 1.0,
                          'linestyle': '-'},
                 })
        else:
            self._axinfo.update(
                {'label': {'va': 'center',
                           'ha': 'center'},
                 'tick': {'inward_factor': 0.2,
                          'outward_factor': 0.1,
                          # Fall back to the xtick width when no
                          # '<adir>tick.major.width' rcParam exists.
                          'linewidth': rcParams.get(
                              adir + 'tick.major.width',
                              rcParams['xtick.major.width'])},
                 'axisline': {'linewidth': rcParams['axes.linewidth'],
                              'color': rcParams['axes.edgecolor']},
                 'grid': {'color': rcParams['grid.color'],
                          'linewidth': rcParams['grid.linewidth'],
                          'linestyle': rcParams['grid.linestyle']},
                 })

        maxis.XAxis.__init__(self, axes, *args, **kwargs)

        # data and viewing intervals for this direction
        self.d_interval = d_intervalx
        self.v_interval = v_intervalx
        self.set_rotate_label(rotate_label)

    def init3d(self):
        """
        Create the 3D-specific artists and attach them to the parent axes.

        The artists are the axis line, the pane polygon and the grid-line
        collection.  The axis label and offset text are also switched to
        data coordinates.
        """
        self.line = mlines.Line2D(
            xdata=(0, 0), ydata=(0, 0),
            linewidth=self._axinfo['axisline']['linewidth'],
            color=self._axinfo['axisline']['color'],
            antialiased=True)

        # Store dummy data in Polygon object; `set_pane_pos` replaces it.
        self.pane = mpatches.Polygon(
            np.array([[0, 0], [0, 1], [1, 0], [0, 0]]),
            closed=False, alpha=0.8, facecolor='k', edgecolor='k')
        self.set_pane_color(self._axinfo['color'])

        self.axes._set_artist_props(self.line)
        self.axes._set_artist_props(self.pane)
        self.gridlines = art3d.Line3DCollection([])
        self.axes._set_artist_props(self.gridlines)
        self.axes._set_artist_props(self.label)
        self.axes._set_artist_props(self.offsetText)
        # Need to be able to place the label at the correct location
        self.label._transform = self.axes.transData
        self.offsetText._transform = self.axes.transData

    @cbook.deprecated("3.1")
    def get_tick_positions(self):
        """Return ``(labels, locations)`` for the current major ticks."""
        majorLocs = self.major.locator()
        majorLabels = self.major.formatter.format_ticks(majorLocs)
        return majorLabels, majorLocs

    def get_major_ticks(self, numticks=None):
        """
        Return the base-class major ticks, with their lines, gridline and
        labels set to the axes' data transform so that `draw` can place them
        at projected positions.
        """
        ticks = maxis.XAxis.get_major_ticks(self, numticks)
        for t in ticks:
            for obj in (t.tick1line, t.tick2line, t.gridline,
                        t.label1, t.label2):
                obj.set_transform(self.axes.transData)
        return ticks

    def set_pane_pos(self, xys):
        """Set the pane polygon's vertices to the first two columns of *xys*."""
        self.pane.xy = np.asarray(xys)[:, :2]
        self.stale = True

    def set_pane_color(self, color, alpha=None):
        """
        Set the pane color.

        Parameters
        ----------
        color : color
            Any Matplotlib color specification (RGB or RGBA tuple, name,
            hex string, ...).
        alpha : float, optional
            Alpha value for the pane.  If None, the alpha of *color* is used
            (1 if *color* carries none).

        Notes
        -----
        The color is normalized with `matplotlib.colors.to_rgba`, so an RGB
        tuple no longer has its blue channel mistaken for alpha.  The
        resulting RGBA tuple is stored in ``self._axinfo['color']``.
        """
        color = mcolors.to_rgba(color, alpha)
        self._axinfo['color'] = color
        self.pane.set_edgecolor(color)
        self.pane.set_facecolor(color)
        self.pane.set_alpha(color[-1])
        self.stale = True

    def set_rotate_label(self, val):
        """
        Whether to rotate the axis label: True, False or None.
        If set to None the label will be rotated if longer than 4 chars.
        """
        self._rotate_label = val
        self.stale = True

    def get_rotate_label(self, text):
        """
        Return whether the axis label *text* should be rotated.

        The explicit `set_rotate_label` value is returned if one was given.
        Otherwise the result is whether *text* is longer than 4 characters.
        """
        if self._rotate_label is not None:
            return self._rotate_label
        else:
            return len(text) > 4

    def _get_coord_info(self, renderer):
        """
        Return the geometry shared by `draw` and `draw_pane`.

        Returns
        -------
        mins, maxs : ndarray
            Lower/upper x, y and z bounds of the axes, each padded outward by
            a quarter of *deltas*.
        centers : ndarray
            Midpoints of the unpadded bounds.
        deltas : ndarray
            One twelfth of each unpadded bound's extent.
        tc
            ``self.axes.tunit_cube`` evaluated for the padded bounds and
            ``renderer.M``.  It is indexed by the corner numbers in
            `_PLANES`; element ``[2]`` of each corner is compared below.
        highs : ndarray of bool
            For each dimension i, True when the corners of
            ``_PLANES[2*i + 1]`` have a larger summed element ``[2]`` than
            those of ``_PLANES[2*i]``.
        """
        mins, maxs = np.array([
            self.axes.get_xbound(),
            self.axes.get_ybound(),
            self.axes.get_zbound(),
        ]).T
        centers = (maxs + mins) / 2.
        deltas = (maxs - mins) / 12.
        mins = mins - deltas / 4.
        maxs = maxs + deltas / 4.

        vals = mins[0], maxs[0], mins[1], maxs[1], mins[2], maxs[2]
        tc = self.axes.tunit_cube(vals, renderer.M)
        # Per-plane sum of the corners' z components, in _PLANES order.
        # Only pairwise comparisons are needed, so a sum works as well as a
        # mean.
        avgz = np.array([tc[p1][2] + tc[p2][2] + tc[p3][2] + tc[p4][2]
                         for p1, p2, p3, p4 in self._PLANES])
        highs = avgz[0::2] < avgz[1::2]
        return mins, maxs, centers, deltas, tc, highs

    def draw_pane(self, renderer):
        """
        Draw this axis' background pane.

        The pane is one of the two `_PLANES` candidates for this axis'
        dimension; *highs* from `_get_coord_info` selects which.
        """
        renderer.open_group('pane3d', gid=self.get_gid())

        *_, tc, highs = self._get_coord_info(renderer)

        index = self._axinfo['i']
        if not highs[index]:
            plane = self._PLANES[2 * index]
        else:
            plane = self._PLANES[2 * index + 1]
        xys = [tc[p] for p in plane]
        self.set_pane_pos(xys)
        self.pane.draw(renderer)

        renderer.close_group('pane3d')

    @artist.allow_rasterization
    def draw(self, renderer):
        """
        Draw the axis line, axis label, offset text, grid lines and ticks,
        in that order.

        Grid lines are drawn only when ``self.axes._draw_grid`` is set and
        there are ticks.
        """
        self.label._transform = self.axes.transData
        renderer.open_group('axis3d', gid=self.get_gid())

        ticks = self._update_ticks()

        info = self._axinfo
        index = info['i']

        mins, maxs, centers, deltas, tc, highs = self._get_coord_info(renderer)

        # Determine grid lines: per dimension, minmax holds the bound on the
        # side selected by highs and maxmin holds the opposite bound.
        minmax = np.where(highs, maxs, mins)
        maxmin = np.where(highs, mins, maxs)

        # Draw main axis line.  Its end points differ only in dimension
        # juggled[1], which is this axis' own dimension in every _AXINFO
        # entry.
        juggled = info['juggled']
        edgep1 = minmax.copy()
        edgep1[juggled[0]] = maxmin[juggled[0]]

        edgep2 = edgep1.copy()
        edgep2[juggled[1]] = maxmin[juggled[1]]
        # pep[k, j] is projected coordinate k of end point j.
        pep = np.asarray(
            proj3d.proj_trans_points([edgep1, edgep2], renderer.M))
        centpt = proj3d.proj_transform(*centers, renderer.M)
        self.line.set_data(pep[0], pep[1])
        self.line.draw(renderer)

        # Grid points where the planes meet
        xyz0 = np.tile(minmax, (len(ticks), 1))
        xyz0[:, index] = [tick.get_loc() for tick in ticks]

        # Draw labels
        # The transAxes transform is used because the Text object
        # rotates the text relative to the display coordinate system.
        # Therefore, if we want the labels to remain parallel to the
        # axis regardless of the aspect ratio, we need to convert the
        # edge points of the plane to display coordinates and calculate
        # an angle from that.
        # TODO: Maybe Text objects should handle this themselves?
        dx, dy = (self.axes.transAxes.transform([pep[0:2, 1]]) -
                  self.axes.transAxes.transform([pep[0:2, 0]]))[0]
        # Angle of the projected axis line.  It is shared by the axis label
        # (when rotated) and the offset text.
        angle = art3d._norm_text_angle(np.rad2deg(np.arctan2(dy, dx)))

        lxyz = 0.5 * (edgep1 + edgep2)

        # A rough estimate; points are ambiguous since 3D plots rotate
        ax_scale = self.axes.bbox.size / self.figure.bbox.size
        ax_inches = np.multiply(ax_scale, self.figure.get_size_inches())
        ax_points_estimate = sum(72. * ax_inches)
        deltas_per_point = 48 / ax_points_estimate
        default_offset = 21.
        labeldeltas = (
            (self.labelpad + default_offset) * deltas_per_point * deltas)
        # Text is pushed away from the center in every dimension except
        # this axis' own.  The same mask is reused for the tick labels below.
        axmask = [True, True, True]
        axmask[index] = False
        lxyz = move_from_center(lxyz, centers, labeldeltas, axmask)
        tlx, tly, _ = proj3d.proj_transform(*lxyz, renderer.M)
        self.label.set_position((tlx, tly))
        if self.get_rotate_label(self.label.get_text()):
            self.label.set_rotation(angle)
        self.label.set_va(info['label']['va'])
        self.label.set_ha(info['label']['ha'])
        self.label.draw(renderer)

        # Draw Offset text

        # Which of the two edge points do we want to
        # use for locating the offset text?
        if juggled[2] == 2:
            outeredgep = edgep1
            outerindex = 0
        else:
            outeredgep = edgep2
            outerindex = 1

        pos = move_from_center(outeredgep, centers, labeldeltas, axmask)
        olx, oly, _ = proj3d.proj_transform(*pos, renderer.M)
        self.offsetText.set_text(self.major.formatter.get_offset())
        self.offsetText.set_position((olx, oly))
        self.offsetText.set_rotation(angle)
        # Must set rotation mode to "anchor" so that
        # the alignment point is used as the "fulcrum" for rotation.
        self.offsetText.set_rotation_mode('anchor')

        # ---------------------------------------------------------------------
        # Note: the following logic for determining the proper alignment of
        # the offset text was determined entirely by trial-and-error and
        # should not in any way be considered "the way".  There are still
        # some edge cases where alignment is not quite right, but this seems
        # to be more of a geometry issue (in other words, I might be using
        # the wrong reference points).
        #
        # (TT, FF, TF, FT) are the shorthand for the tuple of
        #   (centpt[info['tickdir']] <= pep[info['tickdir'], outerindex],
        #    centpt[index] <= pep[index, outerindex])
        #
        # Three-letters (e.g., TFT, FTT) are short-hand for the array of bools
        # from the variable 'highs'.
        # ---------------------------------------------------------------------
        if centpt[info['tickdir']] > pep[info['tickdir'], outerindex]:
            # if FT and if highs has an even number of Trues
            if (centpt[index] <= pep[index, outerindex]
                    and np.count_nonzero(highs) % 2 == 0):
                # Usually, this means align right, except for the FTT case,
                # in which offset for axis 1 and 2 are aligned left.
                if highs.tolist() == [False, True, True] and index in (1, 2):
                    align = 'left'
                else:
                    align = 'right'
            else:
                # The FF case
                align = 'left'
        else:
            # if TF and if highs has an even number of Trues
            if (centpt[index] > pep[index, outerindex]
                    and np.count_nonzero(highs) % 2 == 0):
                # Usually mean align left, except if it is axis 2
                if index == 2:
                    align = 'right'
                else:
                    align = 'left'
            else:
                # The TT case
                align = 'right'

        self.offsetText.set_va('center')
        self.offsetText.set_ha(align)
        self.offsetText.draw(renderer)

        if self.axes._draw_grid and len(ticks):
            # Grid lines go from the end of one plane through the plane
            # intersection (at xyz0) to the end of the other plane.  The first
            # point (0) differs along dimension index-2 and the last (2) along
            # dimension index-1; negative indices wrap, so these are the two
            # dimensions other than this axis' own.
            lines = np.stack([xyz0, xyz0, xyz0], axis=1)
            lines[:, 0, index - 2] = maxmin[index - 2]
            lines[:, 2, index - 1] = maxmin[index - 1]
            self.gridlines.set_segments(lines)
            self.gridlines.set_color(info['grid']['color'])
            self.gridlines.set_linewidth(info['grid']['linewidth'])
            self.gridlines.set_linestyle(info['grid']['linestyle'])
            self.gridlines.draw(renderer, project=True)

        # Draw ticks
        tickdir = info['tickdir']
        tickdelta = deltas[tickdir]
        if highs[tickdir]:
            ticksign = 1
        else:
            ticksign = -1

        # Loop invariants: the tick-mark extent on either side of the axis
        # line (along tickdir), the tick line width and the label offset.
        outward = info['tick']['outward_factor'] * ticksign * tickdelta
        inward = info['tick']['inward_factor'] * ticksign * tickdelta
        tick_linewidth = info['tick']['linewidth']
        tick_default_offset = 8.  # A rough estimate

        for tick in ticks:
            # Get tick line positions
            pos = edgep1.copy()
            pos[index] = tick.get_loc()
            pos[tickdir] = edgep1[tickdir] + outward
            x1, y1, _ = proj3d.proj_transform(*pos, renderer.M)
            pos[tickdir] = edgep1[tickdir] - inward
            x2, y2, _ = proj3d.proj_transform(*pos, renderer.M)

            # Get position of label
            labeldeltas = (
                (tick.get_pad() + tick_default_offset)
                * deltas_per_point * deltas)
            pos[tickdir] = edgep1[tickdir]
            pos = move_from_center(pos, centers, labeldeltas, axmask)
            lx, ly, _ = proj3d.proj_transform(*pos, renderer.M)

            tick_update_position(tick, (x1, x2), (y1, y2), (lx, ly))
            tick.tick1line.set_linewidth(tick_linewidth)
            tick.draw(renderer)

        renderer.close_group('axis3d')
        self.stale = False

    # TODO: Get this to work properly when mplot3d supports
    # the transforms framework.
    def get_tightbbox(self, renderer):
        """Return None; a tight bbox is not computed for 3D axes yet."""
        # Currently returns None so that Axis.get_tightbbox
        # doesn't return junk info.
        return None

    @property
    def d_interval(self):
        """The data interval, via `get_data_interval`/`set_data_interval`."""
        return self.get_data_interval()

    @d_interval.setter
    def d_interval(self, minmax):
        self.set_data_interval(*minmax)

    @property
    def v_interval(self):
        """The view interval, via `get_view_interval`/`set_view_interval`."""
        return self.get_view_interval()

    @v_interval.setter
    def v_interval(self, minmax):
        self.set_view_interval(*minmax)
- # Use classes to look at different data limits
class XAxis(Axis):
    """
    3D x-axis whose view/data intervals are the ``intervalx`` of the parent
    axes' ``xy_viewLim`` / ``xy_dataLim``.
    """
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "xy_viewLim", "intervalx")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "xy_dataLim", "intervalx")
class YAxis(Axis):
    """
    3D y-axis whose view/data intervals are the ``intervaly`` of the parent
    axes' ``xy_viewLim`` / ``xy_dataLim``.
    """
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "xy_viewLim", "intervaly")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "xy_dataLim", "intervaly")
class ZAxis(Axis):
    """
    3D z-axis whose view/data intervals are read from the separate
    ``zz_viewLim`` / ``zz_dataLim`` objects of the parent axes.
    """
    # The z limits are kept in the x-interval of the zz_* limit objects,
    # hence "intervalx" rather than a z-specific attribute.
    get_view_interval, set_view_interval = maxis._make_getset_interval(
        "view", "zz_viewLim", "intervalx")
    get_data_interval, set_data_interval = maxis._make_getset_interval(
        "data", "zz_dataLim", "intervalx")
|