Module curlew.geology.geomodel

A class representing a time-aware geological model and facilitating interactions with the underlying linked-list of GeoField instances (that represent each geological structure in the model).

Classes

class GeoModel (fields: list, transform=None, grid=None, name=None)
Expand source code
class GeoModel(object):
    """
    A class representing a time-aware geological model and
    facilitating interactions with the underlying linked-list of GeoField instances
    (that represent each geological structure in the model).
    """

    def __init__( self, fields : list, transform=None, grid=None, name=None):
        """
        Construct a GeoModel from a list of GeoFields.

        Parameters
        ----------
        fields : list
            A list of GeoField instances representing geological events, from oldest to youngest. This list 
            can include domain boundaries if needed, but non-domain fields (e.g., faults, stratigraphy, etc.)
            should not be older than these.
        transform : `curlew.core.Transform`
            A Transform object defining the transform from global coordinates to model coordinates. This will be applied
            to all `x` when `self.predict(x)` is called, and can handle e.g., converting UTM to some model coordinate system.
            Defaults to an identity matrix (no transform).
        grid : curlew.geometry.Grid | optional
            An optional grid to associate with this GeoModel instance. This sets the `M.grid` attribute but is not
            required (i.e. it can be `None`, which is the default).
        name : str | optional
            A string name to associate with this GeoModel. Not really used, but can be useful :-)
        """
        # set parent and child properties of underlying GeoFields
        # (i.e. build our linked list / binary tree of GeoFields)
        _linkF(fields)

        # traverse back down linked list / tree and define IDs. The youngest event
        # (last in `fields`) gets ID 1; older events get larger IDs in depth-first order.
        def traverse( node, i = 1):
            node.eid = i
            if isinstance( node.parent, GeoField ):
                i = traverse( node.parent, i+1)
            if isinstance( node.parent2, GeoField):
                i = traverse( node.parent2, i+1)
            return i
        traverse( fields[-1] )

        # accumulate all fields in this model
        self.fields = []
        def traverse_fields( node ):
            if isinstance(node, GeoField):
                self.fields.append(node)
            if isinstance(node.parent, GeoField):
                traverse_fields( node.parent )
            if isinstance(node.parent2, GeoField):
                traverse_fields( node.parent2 )
        traverse_fields( fields[-1] ) # traverse from last field in the list
        self.fields = self.fields[::-1] # we want the youngest field last, so reverse the list
        self.lastEvent = self.fields[-1] # change to evaluate model in some paleo-space
        self.eidLookup = { f.eid : f for f in self.fields } # create a lookup table for translating event IDs to GeoField instances

        self.input_dim = self.fields[0].input_dim # get dimensionality of model from one of the fields in the model
        for f in self.fields: # check dimensionalities all match
            assert f.input_dim == self.input_dim, f"Field {f.name} has a dimensionality of {f.input_dim} not {self.input_dim}."
        if transform is None:
            self.T = Transform(self.input_dim) # thunk -- leave this as None and skip compute? (slightly faster)
        else:
            self.T = transform

        # store "nice-to-have" extras
        self.grid = grid
        self.name = name

    def freeze( self, name=None, geometry=True, params=False ):
        """
        Freeze the specified field or parameter. Used to e.g., optimise
        fault offset while keeping fault geometry fixed.

        Parameters
        ------------
        name : str | int | GeoField | list | tuple, optional
            The name (str) or event ID (int) of the GeoField to freeze, or the GeoField instance itself.
            Can also be a list or tuple of these. If None, the specified freeze will be applied to all
            GeoFields in this model. Names and IDs are resolved with `self[...]` (see `getField`).
        geometry : bool
            Value assigned to the `frozen` flag of each field's geometry (`f.field`). Default is True. 
        params : bool
            Value assigned to the `frozen` flag of the other (potentially learnable) components of each
            field, i.e. `deformation`, `propertyField` and `overprint` when defined. Default is False.

        Raises
        ------
        KeyError
            If an integer event ID does not correspond to any field in this model.
        """
        if name is None:
            name = list(self.fields) # apply to all
        if not isinstance(name, (list, tuple)):
            # wrap single items (previously tuples were wrapped too, due to an operator-precedence bug)
            name = [name]
        for f in name:
            if isinstance(f, (str, int, np.integer)):
                key = f
                f = self[key] # get field by name or ID
                if f is None: # getField returns None for unknown event IDs
                    raise KeyError(f"No field with event ID {key} exists in this model.")
            f.field.frozen = geometry # freeze geometry?
            if f.deformation is not None: # freeze potentially learnable properties?
                f.deformation.frozen = params
            if f.propertyField is not None:
                f.propertyField.frozen = params
            if f.overprint is not None:
                f.overprint.frozen = params

    def prefit(self, epochs, **kwargs):
        """
        Train all GeoFields in this model to fit their respective constraints
        in isolation, starting with the youngest field.

        Parameters
        ----------
        epochs : int
            The number of epochs to train for.
        
        Keywords
        ----------
        All keywords are passed to `curlew.fields.NF.fit(...)`. These include:
        learning_rate : float, optional
            Reset each GeoField's optimiser to the specified learning rate before training.
        best : bool, optional
            After training set neural field weights to the best loss.
        vb : bool, optional
            Display a tqdm progress bar to monitor training.

        Returns
        -------
        details : dict
            The loss breakdowns returned by each field's `fit(...)` call (the second element
            of each returned tuple), merged into a single dictionary.
        """
        out = {}
        for F in self.fields[::-1]:
            _, loss = F.fit( epochs, prefix=F.name, **kwargs )
            out.update(loss) # add outputs
        return out

    def zero(self):
        """
        Zero all (unfrozen) optimisers associated with the neural fields and 
        other learned parameters (e.g., fault offsets) in this model. 
        """
        for f in self.fields:
            f.zero()

    def step(self):
        """
        Step all (unfrozen) optimisers associated with the fields making
        up this model, and (potentially) other learned parameters (e.g., fault offsets). 
        """
        for f in self.fields:
            f.step()

    def fit(self, epochs, learning_rate=None, early_stop=(100, 1e-4), best=True, vb=True, prefix='Training'):
        """
        Train all GeoFields in this model to fit the specified constraints
        simultaneously.

        Each epoch sums the losses of all fields (youngest first), back-propagates
        the total and steps every optimiser (via `self.zero()` / `self.step()`).

        Parameters
        ----------
        epochs : int
            The maximum number of training epochs. If `epochs <= 0` no training is performed
            and the loss of the current model state is returned.
        learning_rate : float, optional
            Reset each GeoField's optimiser to the specified learning rate before training.
        early_stop : tuple, optional
            Tuple containing early stopping criterion. This should be (n,t) such that optimisation
            stops once more than n consecutive iterations fail to improve on the best loss so far
            by more than t. Set to None to disable.
        best : bool, optional
            Intended to restore the weights with the best loss after training. Best-state
            restoration is currently disabled, so this flag has no effect.
        vb : bool, optional
            Display a tqdm progress bar to monitor training.
        prefix : str, optional
            The prefix used for the tqdm progress bar.

        Returns
        -------
        loss : float
            The total loss computed in the last iteration (i.e. before the final optimiser step,
            unless training stopped early).
        details : dict
            A more detailed breakdown of the final loss, merged from each field's `loss()` details. 
        """

        # set learning rate if specified
        if learning_rate is not None:
            for F in self.fields:
                F.set_rate( learning_rate )
        
        # setup progress bar
        bar = range(epochs)
        if vb:
            bar = tqdm(range(epochs), desc=prefix, bar_format="{desc}: {n_fmt}/{total_fmt}|{postfix}")

        out = {}

        def _total_loss():
            # Sum the losses of all fields (youngest first) and merge their detail dicts into `out`.
            total = 0
            for F in self.fields[::-1]:
                ll, details = F.loss() # compute loss for this field
                total = total + ll # accumulate loss
                out.update(details) # store for output
            return total

        # iterate
        #best_state = []
        best_loss = np.inf
        best_count = 0
        eps = 0
        if early_stop is not None:
            eps = early_stop[1]
        loss = None
        for _ in bar:
            loss = _total_loss()

            # also add forward (property) reconstruction loss
            # if self.forward is not None:
            #     pp = self.forward.C.pp # position of property constraints
            #     pv = self.forward.C.pv # value of property constraints
            #     spred = self.fields[-1].predict(pp, combine=True, to_numpy=False) # automatically recursed back throught the linked list.
            #     # One Hot encoding
            #     if self.forward.H.one_hot:
            #         one_hot_encoder = torch.nn.functional.one_hot((spred[:, 1] - 1).long(), num_classes=len(self.fields))
            #         encoded_spred = one_hot_encoder * spred[:, 0][:, None]
            #         ppred = self.forward( encoded_spred )
            #     else:
            #         ppred = self.forward( spred ) # generate property predictions
            #     prop_loss = self.forward.loss_func( ppred, pv ) # compute loss
            #     if isinstance( self.forward.H.prop_loss, str):
            #         self.forward.H.prop_loss = float(self.forward.H.prop_loss) / prop_loss.item()
            #     loss = loss + self.forward.H.prop_loss * prop_loss
            #     out['forward'] = (prop_loss.item(),{})

            # track the best loss: only an improvement of more than `eps` resets the counter
            current = loss.item()
            if current < best_loss - eps:
                #best_state = [ copy.deepcopy( F.field.state_dict()  ) for F in self.fields ]
                best_loss = current
                best_count = 0
            else:
                best_count += 1

            # early stopping
            if (early_stop is not None) and (best_count > early_stop[0]):
                break

            # update progress
            if vb:
                bar.set_postfix({ k : v[0] for k,v in out.items() })

            self.zero() # zero gradients
            loss.backward() # backprop losses
            self.step()

        # set best state
        #if best_state:
        #    for i,F in enumerate(self.fields):
        #        F.field.load_state_dict(best_state[i])

        if loss is None:
            # no epochs were run; evaluate the current state so a loss can still be returned
            loss = _total_loss()

        return loss.item(), out

    def predict(self, x : np.ndarray, coords="global", **kwargs):
        """
        Create model predictions at the specified points.

        Parameters
        ----------
        x : np.ndarray | torch.tensor | curlew.geometry.Grid
            An array of shape (N, input_dim) containing the coordinates at which to evaluate
            this GeoModel.
        coords : str
            Specify which coordinate system `x` is in. If `coords == "global"` (default), then any defined
            model transform will be applied (to derive model coordinates). If `coords=="model"` then this
            transform will not be applied.
        
        Keywords
        --------
        All keywords are passed directly to `GeoField.predict()`. `to_numpy` defaults to True and
        `combine` is always forced to True.

        Returns
        --------
        out : Geode
            The result of the youngest GeoField's `predict(...)` (which recurses back through the linked
            list), with `out.grid` set to the input Grid (or None). If global coordinates were used,
            `out.x` holds the original (untransformed) input coordinates and `out.crs` is "global".
        """

        # update isosurface lookup (incase the defined isosurfaces have been changed)
        # build lithology lookup (to ensure lithologies from different fields get unique IDs)
        self.llookup = {}
        self.eidLookup = { f.eid : f for f in self.fields } # create a lookup table for translating event IDs to GeoField instances
        n=1 # start at 1, as -1 is 'undefined' and 0 is default for fields with no lithology defined.
        for F in self.fields:
            self.llookup[F.name] = n # potential lithology created by this field (e.g., constant fields)
            n = n + 1
            if F.overprint is not None:  # only relevant for generative (overprinting) events [ as these "create" new rocks ]
                for k in F.isosurfaces.keys():
                    k = f"{F.name}_{k}" # build key using field name and lithology name
                    assert k not in self.llookup, f"All isosurfaces in model must have unique names, but {k} is not unique!"
                    self.llookup[k] = n # assign ID for this lithology
                    n = n + 1 # increment ID
            F.llookup = self.llookup # link lookup to field so it is used during predict(...).
        
        grid = None
        if isinstance(x, Grid):
            grid = x
            x = grid.coords()

        # apply transform to x, keeping a reference to the untransformed (global) coordinates
        use_global = "global" in coords.lower()
        x_global = x
        if use_global:
            x = self.T(x) # transform from world to model coordinates

        # generate predictions
        kwargs.setdefault('to_numpy', True)
        kwargs['combine'] = True # this is necessary....
        out = self.fields[-1].predict(x, **kwargs) # automatically recursed back through the linked list.
        
        out.grid = grid
        if use_global:
            out.x = x_global # report the original global coords (not the transformed model coords)
            out.crs = "global"
         
        return out

    def drill( self, start, end, step ):
        """
        Evaluate the model along a line between start and end with an interval of step.

        Points are placed at `start + i * step` along the line for `i = 0 .. int(length / step) - 1`
        (at least one point, `start`, is always evaluated).

        Parameters
        -----------
        start : np.ndarray
            The start coordinate of the "drillhole"
        end : np.ndarray
            The end coordinate of the "drillhole"
        step : float
            The distance between points along this line. Must be positive.

        Returns
        ---------
        drillholes : Geode
            A Geode instance containing the results given by evaluating the model along the drillhole.
        contacts : Geode
            A Geode instance containing the positions and orientations of contacts intersected along
            the drillhole, or None if no contacts were intersected.

        Raises
        ------
        ValueError
            If `step` is not positive, or `start` and `end` are the same point.
        """
        start = np.asarray(start, dtype=float)
        end = np.asarray(end, dtype=float)
        if not step > 0:
            raise ValueError(f"step must be positive, got {step}.")
        offset = end - start
        length = np.linalg.norm(offset)
        if length == 0:
            raise ValueError("start and end must be different points.")
        increment = (offset / length) * step
        n_points = max(1, int(length / step)) # always sample at least the start point
        pos = start + np.arange(n_points)[:, None] * increment

        # evaluate model along drillholes
        g = self.predict( pos )

        # find contacts (points where the lithology ID changes from the previous sample)
        c = None
        g._contactMask = np.abs( np.diff( g.lithoID, prepend=g.lithoID[0] ) ) > 0
        if g._contactMask.any():
            cpos = pos[g._contactMask]
            c = self.predict( cpos, gradient=True ) # predict again, at the contact points only

        return g, c

    # def evaluate( self, grid, topology=False, buffer=None, surfaces=None, vb=True):
    #     """
    #     Evaluate a *curlew* model on a grid and extract isosurfaces, topology and/or fault buffers.

    #     Parameters
    #     ----------
    #     grid : curlew.geometry.Grid | np.ndarray
    #         A structured Grid to evaluate the model on (if surfaces are to be calculated), or an array
    #         of coordinates (unstructured grid). Isosurfaces cannot be calculate for unstructured grids.
    #     topology : bool, optional
    #         True if model topology (fault hangingwall and footwall relations) should be calculated and returned. Default is False. 
    #     buffer : float, optional
    #         If not None, this distance (in model coordinates) will be used to compute a buffer of this size on either side of each fault surface.
    #     surfaces : str | bool, optional
    #         If not None, isosurfaces will be computed and returned. If a string is passed, these will also be saved to PLY in the specified folder.

    #     Returns
    #     -------
    #     A dict containing some of the following keys: 'topology', 'buffer', 'surfaces'.
    #     """

    #     # TODO - extend this to include e.g., lithological classifications, stratigraphic contacts, etc.
    #     from curlew.geometry import Grid
    #     if isinstance(grid, Grid):
    #         gxy = grid.coords()
    #     else:
    #         surfaces = None # disable surfaces
    #         gxy = grid

    #     # setup output array
    #     out = dict()
    #     if buffer:
    #         out['buffer'] = np.zeros( len(gxy) ) # initialise fault buffer
    #     if topology:
    #         out['topology'] = np.zeros( (len(gxy), len(self.fields)) ) # array to store hanging-wall & footwall information
    #     if surfaces:
    #         out['surfaces'] = {}

    #     # recurse through model extracting required info
    #     def recurse( f, dmask, i=0 ):
    #         # evaluate model
    #         if (f.parent2 is not None) or (f.deformation is not None): # ignore stratigraphic fields
    #             if vb:
    #                 print(f"Evaluating field {i}/{len(self.fields)}", end='\r')
    #             pred = batchEval( gxy, f.predict, vb=False)[:,0]
    #             pred[dmask] = np.nan # remove masked areas

    #         # evaluate topology, buffer & recurse
    #         if f.parent2 is not None: # this is a domain boundary
    #             iso = f.getIsovalue( f.bound )

    #             if buffer:
    #                 i0 = f.getIsovalue( f.bound, offset=-buffer ) # lower buffer
    #                 i1 = f.getIsovalue( f.bound, offset=buffer ) # upper buffer
    #                 out['buffer'][ (pred >= min(i0, i1)) & (pred <= max(i0, i1)) & (out['buffer'] == 0) ] = f.eid

    #             footwall = pred < iso
    #             hangingwall = pred >= iso
    #             if topology:
    #                 out['topology'][ footwall, i ] = -1
    #                 out['topology'][ hangingwall, i ] = 1

    #             if isinstance(f.parent, GeoField):
    #                 recurse( f.parent, dmask=(dmask + footwall ), i=i+1 ) # recurse hangingwall objects
    #             if isinstance(f.parent2, GeoField):
    #                 recurse( f.parent2, dmask=(dmask + hangingwall), i=i+1 ) # recurse footwall objects

    #         elif (f.deformation is not None) and isinstance(f, FaultOffset): # this is a fault surface
    #             if topology:
    #                 iso = f.getIsovalue( f.deformation_args['contact'], offset=0 ) # fault surface
    #                 footwall = pred < iso
    #                 hangingwall = pred > iso

    #                 out['topology'][ footwall, i ] = -1
    #                 out['topology'][ hangingwall, i ] = 1

    #             if buffer:
    #                 i0 = f.getIsovalue( f.deformation_args['contact'], offset=-buffer ) # lower buffer
    #                 i1 = f.getIsovalue( f.deformation_args['contact'], offset=buffer ) # upper buffer
    #                 out['buffer'][ (pred >= min(i0, i1)) & (pred <= max(i0, i1)) & (out['buffer'] == 0) ] = f.eid

    #             if isinstance(f.parent, GeoField):
    #                 recurse( f.parent, dmask=dmask, i=i+1 ) # recurse older objects
    #         else:
    #             iso = None

    #         if surfaces: # compute isosurface meshes
    #             out['surfaces'][f.name] = {}
    #             for k in f.isosurfaces.keys():
    #                 if grid.ndim == 3: # 3D
    #                     verts, faces = grid.contour( pred, iso=f.getIsovalue(k))
    #                     out['surfaces'][f.name][k] = (verts, np.array(faces))
    #                     if isinstance(surfaces, str) or isinstance(surfaces, Path):
    #                         from curlew.io import savePLY
    #                         savePLY( Path( surfaces ) / str(f.name) / f'{str(k)}.ply',
    #                                 xyz = verts, faces = faces )
    #                 elif grid.ndim == 2: # 2D
    #                     contours = grid.contour( pred, iso=f.getIsovalue(k))
    #                     out['surfaces'][f.name][k] = contours

    #     # traverse from last event in model
    #     recurse( self.fields[-1], np.full( len(gxy), False ) )

    #     return out

    def __getitem__(self, index ):
        """Get fields by name (str) or event ID (int). See `getField`."""
        return self.getField( index )

    def getField( self, eid ):
        """
        Get the scalar field associated with the specified event ID (int) or name (str).

        Parameters
        ----------
        eid : int | str
            The event ID or field name to retrieve.

        Returns
        -------
        GeoField | None
            The scalar field instance associated with the specified event ID, or None if no
            field has that event ID. An AssertionError is raised if no field has the given name.
        """
        if isinstance(eid, str):
            for f in self.fields:
                if f.name == eid: return f
            assert False, f"A field with name {eid} does not exist in this model."
        else:
            return self.eidLookup.get( int(eid), None)

    def _getPositions(self, G, node, first_x=0, first_y=0, step_x=10, step_y=10, pos=None):
        """
        Recursively calculate the 2D positions of nodes in a hierarchical structure. Used when plotting
        the model tree as a 2D graph.

        The first successor of each node is placed to the right; the second successor of a domain
        boundary (a GeoField with `parent2` defined) is placed below.

        Parameters
        ----------
        G : networkx.Digraph
             The directed graph containing the nodes to be positioned.
        node : str
            The current node for which to calculate the position. 
        first_x : int, optional
            The initial x-coordinate for the current node.
        first_y : int, optional
            The initial y-coordinate for the current node.
        step_x : int, optional
            The horizontal step size for moving to the right.
        step_y : int, optional
            The vertical step size for moving down.
        pos : dict, optional
            A dictionary to store the positions of nodes.

        Returns
        -------
        pos : dict
            A dictionary mapping each node to its (x, y) position.
        """
        if pos is None:
            pos = {}

        # Assign the position to the current node
        pos[node] = (first_x, first_y)

        # Get the children of the current node
        children = list(G.successors(node))
        if not children:
            return pos

        # Nodes for fixed (non-GeoField) values have no matching field, so check the type first.
        node_field = next((field for field in self.fields if field.name == node), None)
        is_domain = isinstance(node_field, GeoField) and node_field.parent2 is not None

        # The first child always continues to the right
        pos = self._getPositions(G, children[0], first_x + step_x, first_y, step_x, step_y, pos)
        if is_domain and len(children) > 1:
            # The second branch of a domain boundary moves down
            pos = self._getPositions(G, children[1], first_x, first_y - step_y, step_x, step_y, pos)

        return pos

    def _repr_svg_(self):
        """
        Visualize the model tree of a GeoModel and return it as an SVG string.

        Requires `networkx` and `matplotlib`.

        Returns
        -------
        str
            An SVG string representation of the visualized model tree.

        Raises
        ------
        ImportError
            If `networkx` or `matplotlib` is not installed.
        """
        try:
            import networkx as nx
        except ImportError as err:
            raise ImportError("Please install networkx using `pip install networkx`") from err
        graph = nx.DiGraph()

        domain_boundary_color = '#E35B0E'
        dilative_event_color = "#F0C419"
        generative_event_color = '#31B4C2'
        kinematic_event_color = "#A6340B"
        fixed_value_color = "#FAE8B6"
        other_field_color = "#CCCCCC" # fields matching none of the categories above

        has_other = False
        for field in self.fields[::-1]:
            # Determine the color based on the event type
            if field.parent2 is not None: # domain boundary
                color = domain_boundary_color
            elif field.overprint is not None and field.deformation is not None: # dilative event
                color = dilative_event_color
            elif field.overprint is not None: # generative event
                color = generative_event_color
            elif field.deformation is not None: # kinematic event
                color = kinematic_event_color
            else:
                # previously None, which matplotlib cannot interpret as a node colour
                color = other_field_color
                has_other = True
            graph.add_node(field.name, label=field.name, color=color)

            # Add edges
            if isinstance(field.parent, GeoField):
                graph.add_edge(field.name, field.parent.name)
            if isinstance(field.parent2, GeoField):
                graph.add_edge(field.name, field.parent2.name)
            if not isinstance(field.parent, GeoField) and field.parent is not None: # Handle fixed values
                graph.add_edge(field.name, str(field.parent))
            if not isinstance(field.parent2, GeoField) and field.parent2 is not None:
                graph.add_edge(field.name, str(field.parent2))

        try:
            import matplotlib.pyplot as plt
        except ImportError as err:
            raise ImportError("Please install matplotlib to use plotting tools.") from err

        fig, ax = plt.subplots(1,1, figsize=(8, 4))
        try:
            # the first node added is the youngest field, i.e. the root of the tree
            pos = self._getPositions(graph, list(graph.nodes())[0], step_x=1, step_y=1)
            # nodes only created via add_edge (fixed values) have no colour attribute
            node_colors = [graph.nodes[node].get('color', fixed_value_color) for node in graph.nodes()]
            nx.draw(graph, pos, with_labels=True, arrows=True, node_size=2000,
                    node_color=node_colors, font_size=8, ax=ax)

            # Legend
            legend_labels = {
                domain_boundary_color : 'Domain',
                dilative_event_color : 'Dilative',
                generative_event_color : 'Generative',
                kinematic_event_color : 'Kinematic',
                fixed_value_color : 'Fixed'
            }
            if has_other:
                legend_labels[other_field_color] = 'Other'
            for color, label in legend_labels.items():
                ax.scatter([], [], color=color, label=label, s=200)
            ax.legend(loc='lower center', ncol=len(legend_labels))

            # Render the figure to SVG
            with io.StringIO() as buffer:
                fig.savefig(buffer, format='svg')
                svg = buffer.getvalue()
        finally:
            plt.close(fig) # always release the figure, even if drawing fails

        return svg

A class representing a time-aware geological model and facilitating interactions with the underlying linked-list of GeoField instances (that represent each geological structure in the model).

Construct a GeoModel from a list of GeoFields.

Parameters

fields : list
A list of GeoField instances representing geological events, from oldest to youngest. This list can include domain boundaries if needed, but non-domain fields (e.g., faults, stratigraphy, etc.) should not be older than these.
transform : curlew.core.Transform
A Transform object defining the transform from global coordinates to model coordinates. This will be applied to all x when self.predict(x) is called, and can handle e.g., converting UTM to some model coordinate system. Defaults to an identity matrix (no transform).
grid : curlew.geometry.Grid | optional
An optional grid to associate with this GeoModel instance. This sets the M.grid attribute but is not required (i.e. it can be None, which is the default).
name : str | optional
A string name to associate with this GeoModel. Not really used, but can be useful :-)

Methods

def drill(self, start, end, step)
Expand source code
def drill( self, start, end, step ):
    """
    Evaluate the model along a line between start and end with an interval of step.

    Points are placed at `start + i * step` along the line for `i = 0 .. int(length / step) - 1`
    (at least one point, `start`, is always evaluated).

    Parameters
    -----------
    start : np.ndarray
        The start coordinate of the "drillhole"
    end : np.ndarray
        The end coordinate of the "drillhole"
    step : float
        The distance between points along this line. Must be positive.

    Returns
    ---------
    drillholes : Geode
        A Geode instance containing the results given by evaluating the model along the drillhole.
    contacts : Geode
        A Geode instance containing the positions and orientations of contacts intersected along
        the drillhole, or None if no contacts were intersected.

    Raises
    ------
    ValueError
        If `step` is not positive, or `start` and `end` are the same point.
    """
    start = np.asarray(start, dtype=float)
    end = np.asarray(end, dtype=float)
    if not step > 0:
        raise ValueError(f"step must be positive, got {step}.")
    offset = end - start
    length = np.linalg.norm(offset)
    if length == 0:
        raise ValueError("start and end must be different points.")
    increment = (offset / length) * step
    n_points = max(1, int(length / step)) # always sample at least the start point
    pos = start + np.arange(n_points)[:, None] * increment

    # evaluate model along drillholes
    g = self.predict( pos )

    # find contacts (points where the lithology ID changes from the previous sample)
    c = None
    g._contactMask = np.abs( np.diff( g.lithoID, prepend=g.lithoID[0] ) ) > 0
    if g._contactMask.any():
        cpos = pos[g._contactMask]
        c = self.predict( cpos, gradient=True ) # predict again, at the contact points only

    return g, c

Evaluate the model along a line between start and end with an interval of step.

Parameters

start : np.ndarray
The start coordinate of the "drillhole"
end : np.ndarray
The end coordinate of the "drillhole"
step : float
The distance between points along this line

Returns

drillholes : Geode
A Geode instance containing the results given by evaluating the model along the drillhole.
contacts : Geode
A Geode instance containing the positions and orientations of contacts intersected along the drillhole.
def fit(self,
epochs,
learning_rate=None,
early_stop=(100, 0.0001),
best=True,
vb=True,
prefix='Training')
Expand source code
def fit(self, epochs, learning_rate=None, early_stop=(100, 1e-4), best=True, vb=True, prefix='Training'):
    """
    Train all GeoFields in this model to fit the specified constraints
    simultaneously.

    Each epoch sums the losses of all fields (youngest first), back-propagates
    the total and steps every optimiser (via `self.zero()` / `self.step()`).

    Parameters
    ----------
    epochs : int
        The maximum number of training epochs. If `epochs <= 0` no training is performed
        and the loss of the current model state is returned.
    learning_rate : float, optional
        Reset each GeoField's optimiser to the specified learning rate before training.
    early_stop : tuple, optional
        Tuple containing early stopping criterion. This should be (n,t) such that optimisation
        stops once more than n consecutive iterations fail to improve on the best loss so far
        by more than t. Set to None to disable.
    best : bool, optional
        Intended to restore the weights with the best loss after training. Best-state
        restoration is currently disabled, so this flag has no effect.
    vb : bool, optional
        Display a tqdm progress bar to monitor training.
    prefix : str, optional
        The prefix used for the tqdm progress bar.

    Returns
    -------
    loss : float
        The total loss computed in the last iteration (i.e. before the final optimiser step,
        unless training stopped early).
    details : dict
        A more detailed breakdown of the final loss, merged from each field's `loss()` details. 
    """

    # set learning rate if specified
    if learning_rate is not None:
        for F in self.fields:
            F.set_rate( learning_rate )
    
    # setup progress bar
    bar = range(epochs)
    if vb:
        bar = tqdm(range(epochs), desc=prefix, bar_format="{desc}: {n_fmt}/{total_fmt}|{postfix}")

    out = {}

    def _total_loss():
        # Sum the losses of all fields (youngest first) and merge their detail dicts into `out`.
        total = 0
        for F in self.fields[::-1]:
            ll, details = F.loss() # compute loss for this field
            total = total + ll # accumulate loss
            out.update(details) # store for output
        return total

    # iterate
    #best_state = []
    best_loss = np.inf
    best_count = 0
    eps = 0
    if early_stop is not None:
        eps = early_stop[1]
    loss = None
    for _ in bar:
        loss = _total_loss()

        # also add forward (property) reconstruction loss
        # if self.forward is not None:
        #     pp = self.forward.C.pp # position of property constraints
        #     pv = self.forward.C.pv # value of property constraints
        #     spred = self.fields[-1].predict(pp, combine=True, to_numpy=False) # automatically recursed back throught the linked list.
        #     # One Hot encoding
        #     if self.forward.H.one_hot:
        #         one_hot_encoder = torch.nn.functional.one_hot((spred[:, 1] - 1).long(), num_classes=len(self.fields))
        #         encoded_spred = one_hot_encoder * spred[:, 0][:, None]
        #         ppred = self.forward( encoded_spred )
        #     else:
        #         ppred = self.forward( spred ) # generate property predictions
        #     prop_loss = self.forward.loss_func( ppred, pv ) # compute loss
        #     if isinstance( self.forward.H.prop_loss, str):
        #         self.forward.H.prop_loss = float(self.forward.H.prop_loss) / prop_loss.item()
        #     loss = loss + self.forward.H.prop_loss * prop_loss
        #     out['forward'] = (prop_loss.item(),{})

        # track the best loss: only an improvement of more than `eps` resets the counter
        current = loss.item()
        if current < best_loss - eps:
            #best_state = [ copy.deepcopy( F.field.state_dict()  ) for F in self.fields ]
            best_loss = current
            best_count = 0
        else:
            best_count += 1

        # early stopping
        if (early_stop is not None) and (best_count > early_stop[0]):
            break

        # update progress
        if vb:
            bar.set_postfix({ k : v[0] for k,v in out.items() })

        self.zero() # zero gradients
        loss.backward() # backprop losses
        self.step()

    # set best state
    #if best_state:
    #    for i,F in enumerate(self.fields):
    #        F.field.load_state_dict(best_state[i])

    if loss is None:
        # no epochs were run; evaluate the current state so a loss can still be returned
        loss = _total_loss()

    return loss.item(), out

Train all GeoFields in this model to fit the specified constraints simultaneously.

Parameters

epochs : int
The number of epochs to train each GeoField for.
learning_rate : float, optional
Reset each GeoField's optimiser to the specified learning rate before training.
early_stop : tuple, optional
Tuple containing the early stopping criterion. This should be (n, t) such that optimisation stops after n iterations with <= t improvement in the loss. Set to None to disable. Early stopping is applied regardless of the value of best.
best : bool, optional
Intended to set the neural field weights to those with the best loss after training. Best-state restoration is currently disabled, so this flag has no effect.
vb : bool, optional
Display a tqdm progress bar to monitor training.
prefix : str, optional
The prefix used for the tqdm progress bar.

Returns

loss : float
The total loss computed in the last training iteration.
details : dict
A more detailed breakdown of the final loss.
def freeze(self, name=None, geometry=True, params=False)
Expand source code
def freeze( self, name=None, geometry=True, params=False ):
    """
    Freeze (or unfreeze) the specified field(s) or their parameters. Used to e.g., optimise
    fault offset while keeping fault geometry fixed.

    Parameters
    ------------
    name : str | int | GeoField | list | tuple, optional
        The name (or event ID) of the GeoField to freeze. Can also be a list or tuple of
        names, IDs or GeoField instances. If None, the specified freeze will be applied to
        all GeoFields in this model. Strings and ints are resolved via `self[...]`; use
        `'forward'` to address any defined forward model.
    geometry : bool
        Value assigned to the `frozen` flag of each selected GeoField's neural field.
        True (default) freezes the geometry; False unfreezes it.
    params : bool
        Value assigned to the `frozen` flag of the `deformation`, `propertyField` and
        `overprint` components (where defined) of each selected GeoField, e.g. fault slip.
        Default is False, which means these components are *unfrozen* by a plain `freeze()`.
    """
    if name is None:
        name = list(self.fields)  # apply to all fields
    # NOTE: the previous check `not isinstance(name, list) or isinstance(name, tuple)`
    # wrapped tuples as a single element; treat lists and tuples alike instead.
    if not isinstance(name, (list, tuple)):
        name = [name]
    for f in name:
        if isinstance(f, (str, int)):
            f = self[f]  # resolve field name or event ID to a GeoField
        f.field.frozen = geometry  # freeze / unfreeze geometry
        # propagate `params` to any potentially learnable sub-components
        for part in (f.deformation, f.propertyField, f.overprint):
            if part is not None:
                part.frozen = params

Freeze the specified field or parameter. Used to e.g., optimise fault offset while keeping fault geometry fixed.

Parameters

name : str | GeoField | list
The name of the GeoField to freeze. Can also be a list of names or instances. If None,
the specified freeze will be applied to all GeoFields in this model. Use 'forward' to
address any defined forward model.
geometry : bool
True if the geometry of the specified GeoField should be frozen. Default is True.
params : bool
True if other parameters (e.g., fault slip) associated with the specified GeoFields should be frozen. Default is False.
def getField(self, eid)
Expand source code
def getField( self, eid ):
    """
    Get the GeoField associated with the specified event ID (int) or name (str).

    Parameters
    ----------
    eid : int | str
        The event ID or field name to retrieve. Non-string values are converted
        with `int(eid)` before being looked up in `self.eidLookup`.

    Returns
    -------
    GeoField | None
        The GeoField with the specified name or event ID. For event-ID lookups,
        None is returned if the ID is not present in `self.eidLookup`.

    Raises
    ------
    AssertionError
        If `eid` is a string and no field in `self.fields` has that name.
    """
    if isinstance(eid, str):
        for f in self.fields:
            if f.name == eid:
                return f
        # Raised explicitly (instead of `assert False`) so the check survives `python -O`;
        # the AssertionError type is kept for backward compatibility with existing callers.
        raise AssertionError(f"A field with name {eid} does not exist in this model.")
    # NOTE(review): `eidLookup` is rebuilt in predict(); confirm it is also initialised
    # before this is first called with an event ID.
    return self.eidLookup.get( int(eid), None)

Get the scalar field associated with the specified event ID (int) or name (str).

Parameters

eid : int | str
The event ID or field name to retrieve.

Returns

GeoField
The scalar field instance associated with the specified event ID.
def predict(self, x: numpy.ndarray, coords='global', **kwargs)
Expand source code
def predict(self, x : np.ndarray, coords="global", **kwargs):
    """
    Create model predictions at the specified points.

    Parameters
    ----------
    x : np.ndarray | torch.tensor | curlew.geometry.Grid
        An array of shape (N, input_dim) containing the coordinates at which to evaluate
        this GeoModel. If a Grid is passed, its `coords()` are evaluated and the grid is
        attached to the output.
    coords : str
        Specify which coordinate system `x` is in. If `coords` contains "global" (default), then
        the model transform `self.T` will be applied (to derive model coordinates). If
        `coords=="model"` then this transform will not be applied.

    Keywords
    --------
    All keywords are passed directly to `GeoField.predict()`. `to_numpy` defaults to True and
    `combine` is always forced to True.

    Returns
    --------
    out
        The combined prediction object returned by the youngest field's `predict(...)`
        (presumably holding the predicted scalar values and the GeoField that "created" them).
        Its `grid` attribute is set to the input Grid (or None). For global coordinates, `x`
        is set to the untransformed input coordinates and `crs` to "global".

    Raises
    ------
    AssertionError
        If two isosurfaces of overprinting fields map to the same `<field>_<isosurface>` name.
    """

    # Rebuild the lookups on every call, in case fields or their isosurfaces have changed.
    # The lithology lookup gives lithologies from different fields unique IDs.
    self.llookup = {}
    self.eidLookup = {f.eid: f for f in self.fields}  # event ID -> GeoField
    n = 1  # start at 1, as -1 is 'undefined' and 0 is default for fields with no lithology defined.
    for F in self.fields:
        self.llookup[F.name] = n  # potential lithology created by this field (e.g., constant fields)
        n += 1
        if F.overprint is not None:  # only generative (overprinting) events "create" new rocks
            for iso in F.isosurfaces:
                key = f"{F.name}_{iso}"  # build key using field name and lithology name
                if key in self.llookup:
                    raise AssertionError(f"All isosurfaces in model must have unique names, but {key} is not unique!")
                self.llookup[key] = n
                n += 1
        F.llookup = self.llookup  # link lookup to field so it is used during predict(...)

    grid = None
    if isinstance(x, Grid):
        grid = x
        x = grid.coords()

    is_global = "global" in coords.lower()
    # Keep the caller's coordinates: previously `out.x` was assigned the *transformed*
    # (model) coordinates while being labelled crs="global".
    x_global = x
    if is_global:
        x = self.T(x)  # transform from world to model coordinates

    # generate predictions
    kwargs.setdefault('to_numpy', True)
    kwargs['combine'] = True  # this is necessary....
    out = self.fields[-1].predict(x, **kwargs)  # automatically recurses back through the linked list.

    out.grid = grid
    if is_global:
        # NOTE(review): x_global keeps the caller's array type (e.g. torch vs numpy) — confirm
        # this matches what downstream users of `out.x` expect when to_numpy=True.
        out.x = x_global  # replace with global coords
        out.crs = "global"

    return out

Create model predictions at the specified points.

Parameters

x : np.ndarray | torch.tensor | curlew.geometry.Grid
An array of shape (N, input_dim) containing the coordinates at which to evaluate this GeoModel.
coords : str
Specify which coordinate system x is in. If coords == "global" (default), then any defined model transform will be applied (to derive model coordinates). If coords=="model" then this transform will not be applied.

Keywords

All keywords are passed directly to GeoField.predict().

Returns

S : An array of shape (N,1) containing the predicted scalar values and corresponding GeoField
that "created" them.
def prefit(self, epochs, **kwargs)
Expand source code
def prefit(self, epochs, **kwargs):
    """
    Train all GeoFields in this model to fit their respective constraints
    in isolation, starting with the youngest field.

    Parameters
    ----------
    epochs : int
        The number of epochs to train each field for.

    Keywords
    ----------
    All keywords are passed to each field's `fit(...)` method (see `curlew.fields.NF.fit(...)`).
    These include:
    learning_rate : float, optional
        Reset each GeoField's optimiser to the specified learning rate before training.
    best : bool, optional
        After training set neural field weights to the best loss.
    vb : bool, optional
        Display a tqdm progress bar to monitor training.
    `prefix` is always set to each field's name, so it must not be passed here.

    Returns
    -------
    details : dict
        The merged loss breakdowns (the second value returned by each field's `fit(...)`).
        Fields are trained youngest-first, so on key collisions the entry from the older
        field wins. Note that, unlike `fit`, no scalar loss is returned.
    """
    details = {}
    for F in self.fields[::-1]:  # youngest -> oldest
        _, field_details = F.fit(epochs, prefix=F.name, **kwargs)
        details.update(field_details)
    return details

Train all GeoFields in this model to fit their respective constraints in isolation, starting with the youngest field.

Parameters

epochs : int
The number of epochs to train for.

Keywords

All keywords are passed to curlew.fields.NF.fit(…). These include:
learning_rate : float, optional
Reset each GeoField's optimiser to the specified learning rate before training.
best : bool, optional
After training set neural field weights to the best loss.
vb : bool, optional
Display a tqdm progress bar to monitor training.

Returns

loss : float
The loss of the final (best if best=True) model state.
details : dict
A more detailed breakdown of the final loss.
def step(self)
Expand source code
def step(self):
    """
    Advance the (unfrozen) optimisers of every GeoField in this model by one step.

    Delegates to each field's own `step()`, which covers the neural field and any
    other learned parameters it owns (e.g., fault offsets).
    """
    for geofield in self.fields:
        geofield.step()

Step all (unfrozen) optimisers associated with the fields making up this model, and (potentially) other learned parameters (e.g., fault offsets).

def zero(self)
Expand source code
def zero(self):
    """
    Reset the gradients held by the (unfrozen) optimisers of every GeoField in this model.

    Delegates to each field's own `zero()`, covering the neural field and any other
    learned parameters it owns (e.g., fault offsets).
    """
    for geofield in self.fields:
        geofield.zero()

Zero all (unfrozen) optimisers associated with the neural fields and other learned parameters (e.g., fault offsets) in this model.