Module curlew.visualise
Functions for performing crude 2D plotting using matplotlib. Useful for demonstrations, but will need to be extended at some point to be more usable during model development.
Functions
def colour(sf, cmap='tab20', breaks=19)
-
Expand source code
def colour(sf, cmap='tab20', breaks=19):
    """
    Apply a Matplotlib colormap to a scalar field to generate a colorized property set.

    The function maps scalar field values to colors using a specified colormap and
    returns the colors as uint8 values between 0 and 255.

    Parameters
    ----------
    sf : np.ndarray
        A scalar field array representing the values to be colorized.
    cmap : str, optional
        The name of the Matplotlib colormap to use. Default is 'tab20'.
    breaks : int or array-like, optional
        If an integer (Python or NumPy integer), the number of evenly spaced breakpoints
        between the minimum and maximum of `sf`. If an array-like object, the interior
        breakpoints; the minimum and maximum of `sf` are added as the outer boundaries
        (so the given values presumably need to lie between them to keep the boundaries
        increasing).

    Returns
    -------
    np.ndarray
        An array of shape ``sf.shape + (3,)`` holding RGB colours mapped through the
        colormap, as uint8 values in the range [0, 255].

    Raises
    ------
    AssertionError
        If the total number of boundaries exceeds the number of colours in `cmap`.
    """
    from matplotlib.colors import BoundaryNorm  # do here so matplotlib is not a mandatory dependency
    import matplotlib.pyplot as plt

    lo, hi = np.min(sf), np.max(sf)
    # Accept NumPy integers too: otherwise e.g. np.int64(10) falls through to the
    # array-like branch and is used as a single break value.
    if isinstance(breaks, (int, np.integer)):
        # NOTE(review): `lo` appears twice here (prepended, and as linspace's first value).
        # It affects which colours are selected, so it is intentionally left unchanged.
        breaks = np.hstack([lo, np.linspace(lo, hi, breaks)])
    else:
        breaks = np.hstack([lo, breaks, hi])

    cm = plt.get_cmap(cmap)
    n = cm.N
    assert n >= len(breaks), "Number of breaks (%d) must not exceed the number of colours in colormap (%d)" % (len(breaks), n)
    norm = BoundaryNorm(breaks, ncolors=n)
    c = cm(norm(sf))[..., :3]  # drop the alpha channel
    return (c * 255).astype(np.uint8)
Apply a Matplotlib colormap to a scalar field to generate a colorized property set.
The function maps scalar field values to colors using a specified colormap and returns the colors as uint8 values between 0 and 255.
Parameters
sf
:np.ndarray
- A scalar field array representing the values to be colorized.
cmap
:str
, optional- The name of the Matplotlib colormap to use. Default is 'tab20'.
breaks
:int
or array-like
, optional- If an integer, defines the number of breakpoints used to segment the scalar field. If an array-like object, specifies the exact breakpoints.
Returns
np.ndarray
- An array of shape matching
sf
with RGB color values mapped to the colormap, returned as uint8 values in the range [0, 255].
def format_latex_subscript(name)
-
Expand source code
def format_latex_subscript(name):
    """
    Format a name as a LaTeX (mathtext) label, subscripting trailing digits.

    Names made of letters followed by digits are split into a base and a braced
    subscript, e.g. 'f2' -> '$f_{2}$' and 'sigma12' -> '$sigma_{12}$'. Letters are
    not converted to Greek symbols. Any other name is simply wrapped in '$...$'.

    Parameters
    ----------
    name : str
        The name to format. Non-string values are converted with ``str`` first.

    Returns
    -------
    str
        The LaTeX-formatted label.
    """
    import re

    name = str(name)
    # fullmatch (rather than match with '$') so a trailing newline is not silently dropped
    match = re.fullmatch(r"([a-zA-Z]+)([0-9]+)", name)
    if match:
        return f"${match.group(1)}_{{{match.group(2)}}}$"
    return f"${name}$"  # fallback if not matching pattern
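Converts 'f2' to the LaTeX string '$f_{2}$'. Works with multiple letters and digits (e.g., 'sigma12' → '$sigma_{12}$'); letters are not converted to Greek symbols, and names that do not match the letters-then-digits pattern are simply wrapped in '$...$'.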
def get_positions(M, G, node, first_x=0, first_y=0, step_x=10, step_y=5, pos=None)
-
Expand source code
def get_positions(M, G, node, first_x=0, first_y=0, step_x=10, step_y=5, pos=None):
    """
    Recursively calculate the positions of nodes in a hierarchical structure.

    The current node is placed at (first_x, first_y). Its first successor is placed
    `step_x` to the right. For a domain-boundary field (an ``SF`` whose ``parent2``
    is set) the second successor is placed `step_y` below.

    Parameters
    ----------
    M : GeoModel
        The model whose ``fields`` are searched (by ``name``) to find the field
        corresponding to `node`.
    G : networkx.DiGraph
        The directed graph containing the nodes to be positioned.
    node : str
        The current node for which to calculate the position.
    first_x : int, optional
        The initial x-coordinate for the current node.
    first_y : int, optional
        The initial y-coordinate for the current node.
    step_x : int, optional
        The horizontal step size for moving to the right.
    step_y : int, optional
        The vertical step size for moving down.
    pos : dict, optional
        A dictionary to store the positions of nodes (updated in place).

    Returns
    -------
    pos : dict
        A dictionary mapping each node to its (x, y) position.
    """
    if pos is None:
        pos = {}

    # Assign the position to the current node
    pos[node] = (first_x, first_y)

    children = list(G.successors(node))
    if not children:
        return pos

    node_field = next((field for field in M.fields if field.name == node), None)
    # Only branch downwards when the node really is a domain boundary with two distinct
    # successors. Guard against a node with no matching field, and against parent and
    # parent2 resolving to the same node (a single successor).
    is_domain_boundary = (
        node_field is not None
        and node_field.parent2 is not None
        and isinstance(node_field, SF)
        and len(children) > 1
    )

    # Move to the right
    pos = get_positions(M, G, children[0], first_x + step_x, first_y, step_x, step_y, pos)
    if is_domain_boundary:
        # Move down
        pos = get_positions(M, G, children[1], first_x, first_y - step_y, step_x, step_y, pos)
    return pos
Recursively calculate the positions of nodes in a hierarchical structure.
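Parameters
M
:GeoModel
- The model whose fields are searched by name to find the field corresponding to each node.
G
:networkx.DiGraph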
- The directed graph containing the nodes to be positioned.
node
:str
- The current node for which to calculate the position.
first_x
:int
, optional- The initial x-coordinate for the current node.
first_y
:int
, optional- The initial y-coordinate for the current node.
step_x
:int
, optional- The horizontal step size for moving to the right.
step_y
:int
, optional- The vertical step size for moving down.
pos
:dict
, optional- A dictionary to store the positions of nodes.
Returns
pos
:dict
- A dictionary mapping each node to its (x, y) position.
def plot2D(sxy, grid, C=None, ticksize=50, lw=1, cmap='rainbow', levels=None, ax=None, alpha=0.3)
-
Expand source code
def plot2D(sxy, grid, C=None, ticksize=50, lw=1, cmap='rainbow', levels=None, ax=None, alpha=0.3):
    """
    Create a 2D plot of a scalar field and optionally overlay associated constraints.

    Parameters
    ----------
    sxy : np.ndarray
        A (N,) array of scalar values at the grid points, or an array of shape (N, 3)
        or (N, 4) containing RGB(A) values to plot instead.
    grid : curlew.geometry.Grid
        A grid defining the points at which the values of `sxy` are located.
    C : constraint set, optional
        An object exposing ``vp``/``vv``, ``gp``/``gv``, ``gop``/``gov``, ``pp``/``pv``
        and ``sgrid`` attributes (each may be None) whose (2D) points are overlaid.
    ticksize : int, optional
        The size of the orientation ticks (and scatter markers) added to the plot.
    lw : float, optional
        The linewidth to use for plotting orientation constraints.
    cmap : str, optional
        Colormap used for the scalar field, its contours and the value constraints.
    levels : list or None or bool, optional
        A list of contour levels to plot. Set to `None` for automatic selection, or
        `False` to disable contours.
    ax : matplotlib.axes.Axes, optional
        A Matplotlib axes object on which to plot. Defaults to the current axes.
    alpha : float, optional
        Opacity of the background image.

    Returns
    -------
    matplotlib.figure.Figure, matplotlib.axes.Axes
        The generated Matplotlib figure and associated axes.
    """
    import matplotlib.pyplot as plt

    if ax is None:
        ax = plt.gca()  # get current axes

    vmn, vmx = None, None
    pxy = grid.coords()
    shape = grid.shape
    if (pxy is not None) and (sxy is not None):
        xmn, xmx = np.percentile(pxy[:, 0], (0, 100))  # get bounds
        ymn, ymx = np.percentile(pxy[:, 1], (0, 100))
        extent = (xmn, xmx, ymn, ymx)
        # Only treat a trailing axis of length 3/4 as colour channels for >1-D input, so a
        # 1-D scalar field that happens to hold 3 or 4 values is not mistaken for RGB(A).
        if sxy.ndim > 1 and sxy.shape[-1] in (3, 4):
            # plot colours directly
            si = sxy.reshape(shape + (sxy.shape[-1],))
            ax.imshow(np.transpose(si, (1, 0, 2)), alpha=alpha, extent=extent, origin='lower')
        else:
            si = sxy.reshape(shape)  # reshape to image
            vmn, vmx = np.percentile(sxy, (0, 100))
            # plot scalar field and contours
            ax.imshow(si.T, cmap=cmap, alpha=alpha, extent=extent, vmin=vmn, vmax=vmx, origin='lower')
            if levels is not False:
                contour = ax.contour(si.T, cmap=cmap, levels=levels, extent=extent, vmin=vmn, vmax=vmx)
                ax.clabel(contour, inline=True, fontsize=12)

    if C is not None:
        # plot value constraints (skipped when a property constraint is defined)
        if (C.vp is not None) and (C.pp is None):
            if vmn is None:
                vmn, vmx = np.percentile(C.vv.squeeze(), (0, 100))
            ax.scatter(C.vp[:, 0], C.vp[:, 1], c=C.vv.squeeze(), cmap=cmap, vmin=vmn, vmax=vmx,
                       edgecolors='k', zorder=10, s=ticksize)

        # plot gradient constraints: an orange segment along the gradient plus a black
        # tick perpendicular to it
        if C.gp is not None:
            for p, v in zip(C.gp, C.gv):
                ax.plot([p[0], p[0] + v[0] * ticksize], [p[1], p[1] + v[1] * ticksize],
                        color='orange', lw=lw, zorder=10)
                dx = ticksize * v[1]
                dy = -ticksize * v[0]
                ax.plot([p[0] - dx, p[0] + dx], [p[1] - dy, p[1] + dy], color='k', lw=lw, zorder=10)

        # plot orientation constraints (perpendicular tick only)
        if C.gop is not None:
            for p, v in zip(C.gop, C.gov):
                dx = ticksize * v[1]
                dy = -ticksize * v[0]
                ax.plot([p[0] - dx, p[0] + dx], [p[1] - dy, p[1] + dy], color='k', lw=lw, zorder=10)

        # plot property constraints
        if (C.pp is not None) and (C.pv is not None):
            if C.pv.shape[-1] == 1:
                ax.scatter(C.pp[:, 0], C.pp[:, 1], c=C.pv[:, 0])
            elif C.pv.shape[-1] >= 3:
                # rescale the first three channels to [0, 1] independently
                rgb = C.pv[:, [0, 1, 2]].astype(float)
                rgb -= np.min(rgb, axis=0)[None, :]
                rng = np.max(rgb, axis=0)
                rng[rng == 0] = 1.0  # constant channel (e.g. one point): avoid 0/0 -> NaN colours
                rgb /= rng[None, :]
                ax.scatter(C.pp[:, 0], C.pp[:, 1], c=rgb, s=ticksize / 2)

        # plot grid for evaluating global constraints
        if C.sgrid is not None:
            ax.scatter(C.sgrid[:, 0], C.sgrid[:, 1], color='gray', s=ticksize / 3)

    return ax.get_figure(), ax
Create a 2D plot of a scalar field and optionally overlay associated constraints.
Parameters
sxy
:np.ndarray
- A (N,) array of scalar values at the grid points, or an array of shape (N, 3) or (N, 4) containing RGB(A) values to plot instead.
grid
:Grid
- A grid defining the points at which the values of
sxy
are located.
C
:np.ndarray
, optional- A constraint set containing the (2D) points to overlay on the plot.
ticksize
:int
, optional- The size of the orientation ticks to add to the plot.
lw
:float
, optional- The linewidth to use for plotting orientation constraints.
levels
:list
or None
or bool
, optional- A list of contour levels to plot. Set to
None
for automatic selection, or False
to disable contours.
ax
:matplotlib.axes.Axes
, optional- A Matplotlib axes object on which to plot.
Returns
matplotlib.figure.Figure, matplotlib.Axes
- The generated Matplotlib figure and associated axes.
def plotConstraints(ax, C=None, H=None, ll=1, lw=4, scale=0.001, ac='k', vmn=0, vmx=20, cmap='tab20b')
-
Expand source code
def plotConstraints(ax, C=None, H=None, ll=1, lw=4, scale=0.001, ac="k", vmn=0, vmx=20, cmap="tab20b"):
    """
    Overlay the constraints in `C` on an existing Matplotlib axes.

    Each constraint type is drawn only if it is present in `C` and, when `H` is given,
    its corresponding weight in `H` is non-zero.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on.
    C : constraint set, optional
        Object exposing ``vp``/``vv``, ``gp``/``gv``, ``gop``/``gov`` and ``pp``/``pv``
        attributes (each may be None). Nothing is drawn when `C` is None.
    H : object, optional
        If given, a constraint type is skipped when its weight (``value_loss``,
        ``grad_loss``, ``ori_loss`` or ``prop_loss``) equals 0.
    ll : float, optional
        Length multiplier for the headless bedding-orientation ticks.
    lw : float, optional
        ``width`` passed to ``quiver`` for arrows and ticks; property markers use size ``lw/2``.
    scale : float, optional
        ``quiver`` scale (with ``scale_units='xy'``); smaller values draw longer arrows.
    ac : color, optional
        Colour of arrows, ticks and value-constraint marker edges.
    vmn, vmx : float or None, optional
        Colour limits for value constraints; if `vmn` is None both come from the data range.
    cmap : str, optional
        Colormap for value constraints.
    """
    if C is None:
        return

    def _active(weight_name):
        # a constraint type is drawn unless H explicitly gives it zero weight
        return (H is None) or (getattr(H, weight_name) != 0)

    def _unit_and_perp(vectors):
        # normalise to unit length (epsilon avoids division by zero) and rotate by 90 degrees
        norms = np.linalg.norm(vectors, axis=1, keepdims=True)
        unit = vectors / (norms + 1e-8)
        perp = np.vstack([-unit[:, 1], unit[:, 0]]).T
        return unit, perp

    # shared style for the headless, centred bedding-orientation ticks
    tick_style = dict(color=ac, angles='xy', scale_units='xy', scale=scale, width=lw,
                      headlength=0, headwidth=0, headaxislength=0, pivot="middle")

    # value constraints (skipped when a property constraint is defined)
    if _active("value_loss") and (C.vp is not None) and (C.pp is None):
        if vmn is None:
            vmn, vmx = np.percentile(C.vv.squeeze(), (0, 100))
        ax.scatter(C.vp[:, 0], C.vp[:, 1], c=C.vv.squeeze(), cmap=cmap, vmin=vmn, vmax=vmx,
                   zorder=12, edgecolor=ac)

    # gradient constraints: unit arrow along the gradient plus a perpendicular tick
    if _active("grad_loss") and C.gp is not None:
        positions = C.gp
        unit, perp = _unit_and_perp(C.gv)
        ax.quiver(positions[:, 0], positions[:, 1], unit[:, 0], unit[:, 1],
                  color=ac, angles='xy', scale_units='xy', scale=scale, zorder=10, width=lw)
        ax.quiver(positions[:, 0], positions[:, 1], ll * perp[:, 0], ll * perp[:, 1],
                  zorder=11, **tick_style)

    # orientation constraints: perpendicular tick only, drawn from the orientation
    # constraints themselves (gop/gov), not the gradient constraints
    if _active("ori_loss") and C.gop is not None:
        positions = C.gop
        _, perp = _unit_and_perp(C.gov)
        ax.quiver(positions[:, 0], positions[:, 1], ll * perp[:, 0], ll * perp[:, 1],
                  zorder=10, **tick_style)

    # property constraints
    if _active("prop_loss") and (C.pp is not None) and (C.pv is not None):
        if C.pv.shape[-1] == 1:
            ax.scatter(C.pp[:, 0], C.pp[:, 1], c=C.pv[:, 0])
        elif C.pv.shape[-1] >= 3:
            # rescale the first three channels to [0, 1] independently
            rgb = C.pv[:, [0, 1, 2]].astype(float)
            rgb -= np.min(rgb, axis=0)[None, :]
            rng = np.max(rgb, axis=0)
            rng[rng == 0] = 1.0  # constant channel (e.g. one point): avoid 0/0 -> NaN colours
            rgb /= rng[None, :]
            ax.scatter(C.pp[:, 0], C.pp[:, 1], c=rgb, s=lw / 2)
def plotDrill(hole, ax, ticksize=50, lw=1, vmn=0, vmx=20, cmap='tab20b', noval=False)
-
Expand source code
def plotDrill(hole, ax, ticksize=50, lw=1, vmn=0, vmx=20, cmap="tab20b", noval=False):
    """
    Plot the specified drillhole on the given axis with color normalization.

    The hole trace is drawn as a ``LineCollection`` of segments joining consecutive
    points, coloured by ``hole['classID']`` through `cmap` normalised to [vmn, vmx],
    with a black outline stroke.

    Parameters
    ----------
    hole : dict-like
        Must provide ``'pos'`` (array whose first two columns are x and y) and
        ``'classID'`` (values used to colour the segments; presumably one value per
        segment, i.e. ``len(pos) - 1`` -- TODO confirm).
    ax : matplotlib.axes.Axes
        Axes to draw on.
    ticksize : int, optional
        Unused by this function.
    lw : float, optional
        Line width of the trace (the outline is ``lw + 5``).
    vmn, vmx : float, optional
        Colour normalisation limits.
    cmap : str, optional
        Colormap for ``classID``.
    noval : bool, optional
        Unused by this function.

    Returns
    -------
    matplotlib.collections.LineCollection
        The collection added to `ax` (e.g. for building a colorbar).
    """
    from matplotlib.collections import LineCollection
    from matplotlib.colors import Normalize
    import matplotlib.patheffects as pe

    points = np.array([hole['pos'][:, 0], hole['pos'][:, 1]]).T.reshape(-1, 1, 2)
    # segment i joins point i to point i + 1 -> shape (n_points - 1, 2, 2)
    segments = np.concatenate([points[:-1], points[1:]], axis=1)

    # Normalization to handle vmin/vmax
    norm = Normalize(vmin=vmn, vmax=vmx)

    # Plot the borehole
    lc = LineCollection(segments, array=hole['classID'], cmap=cmap, norm=norm, linewidths=lw)
    artist = ax.add_collection(lc)
    # thicker black stroke underneath the coloured line for contrast
    artist.set_path_effects([pe.Stroke(linewidth=lw + 5, foreground='k'), pe.Normal()])
    return artist
Plot the specified drillhole on the given axis with color normalization.
def showModel(M, axs=None, leg_loc=None, title='c)', node_size=3000, font_size=18)
-
Expand source code
def showModel(M, axs=None, leg_loc=None, title="c)", node_size=3000, font_size=18):
    """
    Visualize the model tree of a GeoModel.

    Each field of `M` becomes a node coloured by event type, with edges to its
    ``parent`` and ``parent2``. Parents that are not ``SF`` instances (fixed values)
    appear as extra nodes labelled ``str(parent)``.

    Parameters
    ----------
    M : GeoModel
        The model whose ``fields`` are drawn.
    axs : matplotlib.axes.Axes, optional
        Axes on which to draw the tree. If None, a new figure with a graph panel and
        a legend panel is created.
    leg_loc : sequence of 4 floats, optional
        Bounds ``[x0, y0, width, height]`` of the inset legend axes (passed to
        ``axs.inset_axes``) when `axs` is given. If None, the legend is drawn
        directly on `axs`.
    title : str, optional
        Title placed at the top-left of the graph axes.
    node_size : int, optional
        Node size passed to ``networkx.draw``.
    font_size : int, optional
        Label font size passed to ``networkx.draw``.

    Returns
    -------
    matplotlib.figure.Figure or None
        The new figure when `axs` is None (closed with ``plt.close``, presumably to
        avoid duplicate display in notebooks); otherwise None.

    Raises
    ------
    ValueError
        If `M` has no fields.
    """
    import matplotlib.pyplot as plt
    import matplotlib.gridspec as gridspec
    from matplotlib.lines import Line2D
    import networkx as nx

    domain_boundary_color = '#E35B0E'
    dilative_event_color = "#F0C419"
    generative_event_color = '#31B4C2'
    kinematic_event_color = "#A6340B"
    fixed_value_color = "#FAE8B6"
    other_field_color = "#D3D3D3"  # fields matching none of the event types above

    graph = nx.DiGraph()
    for field in M.fields[::-1]:
        # Determine the color based on the event type
        if field.parent2 is not None:  # domain boundary
            color = domain_boundary_color
        elif field.bound is not None and field.deformation is not None:  # dilative event
            color = dilative_event_color
        elif field.bound is not None:  # generative event
            color = generative_event_color
        elif field.deformation is not None:  # kinematic event
            color = kinematic_event_color
        else:
            # a None colour would be rejected by matplotlib when drawing the nodes
            color = other_field_color
        graph.add_node(field.name, label=format_latex_subscript(field.name), color=color)

        # Add edges. Insertion order matters: get_positions lays successors out in this order.
        if isinstance(field.parent, SF):
            graph.add_edge(field.name, field.parent.name)
        if isinstance(field.parent2, SF):
            graph.add_edge(field.name, field.parent2.name)
        if not isinstance(field.parent, SF) and field.parent is not None:  # fixed values
            graph.add_edge(field.name, str(field.parent))
        if not isinstance(field.parent2, SF) and field.parent2 is not None:
            graph.add_edge(field.name, str(field.parent2))

    if graph.number_of_nodes() == 0:
        raise ValueError("Model has no fields to display")

    # Plotting
    if axs is None:
        fig = plt.figure(figsize=(10, 6))
        # NOTE(review): no axes exist yet, so this lays nothing out -- confirm whether it
        # was meant to run after the subplots are added.
        fig.tight_layout()
        gs = gridspec.GridSpec(1, 2, width_ratios=[2, 1], wspace=0.01)
        ax_graph = fig.add_subplot(gs[0])
        ax_legend = fig.add_subplot(gs[1])
    else:
        fig = None
        ax_graph = axs
        # inset_axes needs explicit bounds; without them the legend goes on the graph axes
        ax_legend = ax_graph.inset_axes(leg_loc) if leg_loc is not None else None

    # Model Tree: the root is the first node added, i.e. M.fields[-1]
    root = next(iter(graph.nodes))
    pos = get_positions(M, graph, root)
    node_colors = [graph.nodes[n].get('color', fixed_value_color) for n in graph.nodes()]
    labels = {n: graph.nodes[n].get("label", str(n)) for n in graph.nodes()}
    nx.draw(graph, pos, with_labels=True, labels=labels, arrows=True, node_size=node_size,
            node_color=node_colors, font_size=font_size, font_color="k", ax=ax_graph)
    ax_graph.set_title(title, loc="left")

    # Legend
    legend_labels = {
        domain_boundary_color: 'Domain Boundary',
        dilative_event_color: 'Dilative Event',
        generative_event_color: 'Generative Event',
        kinematic_event_color: 'Kinematic Event',
        fixed_value_color: 'Fixed Value',
    }
    if other_field_color in node_colors:
        legend_labels[other_field_color] = 'Other Field'
    handles = [Line2D([0], [0], marker='o', color='k', label=label, markerfacecolor=color, markersize=15)
               for color, label in legend_labels.items()]
    if ax_legend is None:
        ax_graph.legend(handles=handles, loc='best', fontsize=12)
    else:
        ax_legend.axis('off')
        ax_legend.legend(handles=handles, loc='center', fontsize=12)

    if fig is not None:
        plt.close(fig)
        return fig
    return None
Visualize the model tree of a GeoModel.
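Parameters
M
:GeoModel
- The model whose fields are drawn as a tree.
axs
:matplotlib.axes.Axes
, optional- Axes on which to draw the tree. If None, a new figure with separate graph and legend panels is created.
leg_loc
:sequence of 4 floats
, optional- Bounds [x0, y0, width, height] of the inset legend axes, used when axs is given.
title
:str
, optional- Title placed at the top-left of the graph.
node_size
:int
, optional- Node size passed to networkx.draw.
font_size
:int
, optional- Label font size passed to networkx.draw.
Returns
matplotlib.figure.Figure or None
- The new (closed) figure when axs is None; otherwise None.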