Module curlew.fields

Imports the core neural field types from the other Python files in this package and defines the "base" NF class from which they all inherit.

Sub-modules

curlew.fields.analytical

Classes for defining analytical (rather than interpolated) implicit fields, as these can also be used to build geological models.

curlew.fields.fourier

Implement Fourier-feature based neural fields for scalar potential representation.

curlew.fields.geoinr

TODO - implement neural field based on the GeoINR approach.

curlew.fields.initialisation

TODO - implement functions for pre-training (initialising) neural fields to either a planar geometry or to an analytical function (analytical …

curlew.fields.series

Implement series based neural fields for scalar potential representation. The strength of these fields is that they are an analytical representation …
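The Fourier-feature fields in curlew.fields.fourier build on a common idea: pass the coordinates through a fixed random projection before the network, using the encoding γ(x) = [cos(2πBx), sin(2πBx)], so that the network can represent finer-scale variation. The NumPy sketch below shows only that generic encoding. `fourier_features`, the frequency matrix `B` and its spread `sigma` are assumptions for illustration, and the module may choose or scale its frequencies differently.

```python
import numpy as np

def fourier_features(x, B):
    """Map (N, d) coordinates to (N, 2m) features using an (m, d) frequency matrix B."""
    proj = 2.0 * np.pi * (np.asarray(x, dtype=float) @ B.T)   # (N, m) projected phases
    return np.concatenate([np.cos(proj), np.sin(proj)], axis=1)

rng = np.random.default_rng(42)
d, m, sigma = 3, 16, 0.05               # sigma sets typical frequency (~1 / length scale)
B = rng.normal(0.0, sigma, size=(m, d))  # fixed (not learned) random frequencies
x = rng.uniform(-100.0, 100.0, size=(5, d))
feats = fourier_features(x, B)
print(feats.shape)                       # (5, 32)
```

A smaller `sigma` yields smoother fields; a larger one lets the network fit sharper structure at the risk of noisy interpolation.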

Classes

class BaseAF (name: str = None,
input_dim: int = None,
output_dim: int = 1,
C: CSet = None,
H: HSet = None,
drift=0,
transform=None,
local=None,
seed: int = 42,
**kwargs)
class BaseAF(BaseSF):
    """
    Base class for all analytical fields (those implementing specific geometric implicit functions).
    """
    pass # this does nothing special! But is included to easily distinguish analytical, neural and interpolated fields

Base class for all analytical fields (those implementing specific geometric implicit functions).
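An analytical field evaluates a closed-form implicit function rather than a learned one. As a hedged illustration, a planar field can be written as φ(x) = n̂·x − c, whose isosurfaces are parallel planes. The function below is a standalone NumPy sketch, not a BaseAF subclass; the name `planar_potential` and its arguments are hypothetical.

```python
import numpy as np

def planar_potential(xyz, normal, offset=0.0):
    """Hypothetical analytical field: signed distance to the plane n.x = offset.

    xyz    : (N, d) array of positions
    normal : (d,) plane normal (normalised internally)
    offset : value of n.x on the zero isosurface
    """
    xyz = np.asarray(xyz, dtype=float)
    n = np.asarray(normal, dtype=float)
    n = n / np.linalg.norm(n)      # unit normal, so outputs are true distances
    return xyz @ n - offset        # (N,) scalar potential

# horizontal layering: the potential simply equals z (0, 1 and -2 here)
pts = np.array([[0.0, 0.0, 0.0], [5.0, 2.0, 1.0], [1.0, 1.0, -2.0]])
values = planar_potential(pts, normal=[0, 0, 1])
```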

Initialise a new scalar field.

Parameters

name : str
A (ideally unique) name for this neural field. Should typically match the name of the GeoField instance that uses this field. Defaults to the name of this class.
input_dim : int, optional
The dimensionality of the input space (e.g., 3 for [x, y, z], 2 for [x, y]). If None (default), then default_dim is used.
output_dim : int, optional
Dimensionality of the output (usually 1 for a scalar potential).
C : CSet
Constraint set used for learned or interpolated fields. Default is None.
H : HSet
Hyperparameters used to tune the loss function for this NF. Default is None.
drift : int | float | BaseSF
A constant integer or float (to use a constant value as the drift), or another BaseSF instance (e.g., an AnalyticalField) that defines the trend/drift of this field. The drift is evaluated at each input coordinate and added to the output of the field during the forward call, so learnable fields (interpolators) learn a residual relative to it. Default is 0 (no drift).
transform : callable
A function that transforms input coordinates prior to evaluation. It must take exactly one argument (a tensor of positions) and return the transformed positions.
local : curlew.core.Transform, optional
A Transform object defining the mapping from (possibly undeformed) field coordinates to the local coordinates passed into the neural or analytical field that represents this scalar field. It is applied during the forward call, and can be used to, e.g., implement global anisotropy or adjust the represented structures with a constant offset or rotation. Defaults to an identity matrix (no transform).
seed : int, optional
A random seed that child classes can (optionally) use for any random operations.
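The drift behaviour described above can be sketched as output = residual(x) + drift(x), where the drift is either a constant or another field evaluated at the same coordinates. The helper `evaluate_with_drift` and the toy residual are hypothetical, and the sketch deliberately ignores `transform` and `local`, since the order in which curlew applies them relative to the drift is not specified here.

```python
import numpy as np

def evaluate_with_drift(x, residual_fn, drift=0):
    """Hypothetical forward pass: learned residual plus a constant or callable drift."""
    x = np.asarray(x, dtype=float)
    base = residual_fn(x)              # (N,) learned or interpolated part
    if callable(drift):
        return base + drift(x)         # spatially varying trend, evaluated at x
    return base + float(drift)         # constant drift (0 means no drift)

residual = lambda x: 0.1 * np.sin(x[:, 0])   # small learned correction (toy)
trend = lambda x: x[:, 2]                    # e.g. a planar trend increasing with z
pts = np.array([[0.0, 0.0, 5.0], [np.pi / 2, 0.0, -1.0]])
out = evaluate_with_drift(pts, residual, trend)   # approximately 5.0 and -0.9
```

Because the trend already carries most of the signal, the learnable part only has to represent the (usually small) deviation from it.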

Keywords

All keywords are passed to the initField(…) function of the child class, to build the relevant neural architecture.
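The keyword-forwarding convention can be illustrated with a minimal, self-contained sketch: the base constructor stores the common settings and hands any remaining keywords to a child-defined `initField`. The classes below are hypothetical stand-ins rather than curlew's `BaseSF`, and the keyword names `hidden` and `n_layers` are invented for illustration.

```python
class SketchBase:
    """Hypothetical stand-in for a base field that delegates architecture setup."""
    def __init__(self, name=None, input_dim=3, output_dim=1, seed=42, **kwargs):
        self.name = name or type(self).__name__   # default name is the class name
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.seed = seed
        self.initField(**kwargs)                  # child builds its own architecture

    def initField(self, **kwargs):
        raise NotImplementedError("child classes must implement initField")

class SketchMLP(SketchBase):
    def initField(self, hidden=64, n_layers=3):
        # record the layer widths an MLP of this shape would use
        self.widths = [self.input_dim] + [hidden] * n_layers + [self.output_dim]

f = SketchMLP(input_dim=2, hidden=32, n_layers=2)
print(f.name, f.widths)   # SketchMLP [2, 32, 32, 1]
```

Unrecognised keywords then fail loudly inside the child's `initField` signature instead of being silently ignored.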

class BaseNF (name: str,
H: HSet,
C: CSet = None,
input_dim: int = None,
output_dim: int = 1,
transform=None,
seed=42,
vloss=MSELoss(),
scale=100.0,
**kwargs)
class BaseNF(BaseSF):
    """
    A generic base for neural field implementations that learn to translate input coordinates to implicit value (or values). See the other
    child classes in this module (e.g., fourier, geoinr, etc.) for specific implementations.
    """
    def __init__(
            self,
            name : str,
            H: HSet,
            C : CSet = None,
            input_dim: int = None,
            output_dim: int = 1,
            transform = None,
            seed = 42,
            vloss = nn.MSELoss(),
            scale = 1e2,
            **kwargs
        ):
            """
            Parameters
            ----------
            name : str
                A (ideally unique) name for this neural field. Should typically match the name of the GeoField instance that uses this field.
            H : HSet
                Hyperparameters used to tune the loss function for this NF.
            C : CSet, optional
                Constraint set used when learning this implicit field. Default is None (can be set using `field.bind(...)`).
            input_dim : int, optional
                The dimensionality of the input space (e.g., 3 for (x, y, z)). If None (default), then `curlew.default_dim` will be used.
            output_dim : int, optional
                Dimensionality of the output (usually 1 for a scalar potential).
            transform : callable
                A function that transforms input coordinates prior to predictions. Must take exactly one argument as input (a tensor of positions) and return the transformed positions. 
            seed : int, optional
                The random seed to use for any random operations.
            vloss : callable, optional
                The loss function to use for value fitting. Default is mean squared error (`nn.MSELoss()`).
            scale : float, optional
                A scaling factor to apply to outputs of the neural field, as often these struggle to learn functions with a large (>1) amplitude. Default is 1e2. 
                
                This value should be approximately equal to the expected range (max - min) of the scalar field that is being learned. It can be especially important when using a 
                drift (trend), as it determines the extent to which the model initialisation is determined by the drift. Larger values should allow the model to deviate farther from the trend.
                Note also that this term tends to control the magnitude of residuals (to value or (in)equality constraints), so it will also interact with the learning rate.
                
                N.B. The actual implementation of this scale depends on the neural field method being used.

            Keywords
            ---------
            All keywords are passed to the initField(...) function of the child class, to build the relevant
            neural architecture.
            """
            # initialise everything (including calling the initField method of the relevant child class)
            super().__init__(name=name, input_dim=input_dim, output_dim=output_dim, H=H, C=C, transform=transform, seed=seed, **kwargs)

            # store neural field specific properties
            self.closs = torch.nn.CosineSimilarity() # needed by some loss functions
            self.vloss = vloss # loss function to use for value fitting
            self.scale = scale
            # optional worst-pair retention for inequality losses across epochs (see reuse_worst_half in HSet)
            self._last_iq_worst_indices = None

    ## LEARNING
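    def loss(self, transform=True) -> tuple:
        """
        Compute the loss associated with this neural field given its current state. The `transform` argument
        specifies if constraints need to be transformed from modern to paleo-coordinates before computing loss.
        Returns a tuple containing the total loss (a tensor that can be backpropagated) and a dictionary
        that breaks the loss down by term, keyed by the name of this field.
        """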
        if self.C is None:
            raise ValueError("Scalar field has no constraints; bind a CSet first (e.g., using `bind(...)`).")

        # move these into local scope for clarity
        C = self.C 
        H = self.H

        # initialise the different loss parts
        L = {}
        for k in ['value_loss', 'grad_loss', 'ori_loss', 'thick_loss', 'mono_loss', 'flat_loss', 'iq_loss']:
            L[k] = 0

        # LOCAL LOSS FUNCTIONS
        # -----------------------------
        # Value Loss
        if (C.vp is not None) and (C.vv is not None) and (isinstance(H.value_loss, str) or (H.value_loss > 0)):
            v_pred = self(C.vp, transform=transform)
            L['value_loss'] = self.vloss( v_pred, C.vv[:,None] )

        # Gradient loss
        self.reset_mnorm() # reset accumulation
        # [ N.B. positions (and thus gradients) are in un-transformed coordinates ]
        if (C.gp is not None) and (isinstance(H.grad_loss, str) or (H.grad_loss > 0)):
            gv_pred = self.gradient(C.gp, normalize=True, transform=transform, accumulate=True, retain_graph=True, create_graph=True) # compute gradient direction 
            L['grad_loss'] = self.vloss(gv_pred, C.gv) # N.B. constrains both orientation and younging direction

        # Orientation loss
        # [ N.B. positions (and thus gradients) are in un-transformed coordinates ]
        if (C.gop is not None) and (isinstance(H.ori_loss, str) or (H.ori_loss > 0)):
            gv_pred = self.gradient(C.gop, normalize=True, transform=transform, accumulate=True, retain_graph=True, create_graph=True) # compute gradient direction 
            L['ori_loss'] = torch.clamp( torch.mean( 1 - torch.abs( self.closs(gv_pred, C.gov ) ) ), min=1e-6 ) # N.B.: Orientation loss on its own fits a bit too well, numerical precision crashes avoided with the clamp - AVK

        # GLOBAL LOSS FUNCTIONS
        # -------------------------------
        reuse_frac = getattr(H, 'reuse_worst_half', 0) or 0 # set as 0 or None to disable
        if C.grid is not None:
            if transform:
                gridL = C.grid.draw(self.transform)  # sample transformed grid points
            else:
                gridL = C.grid.draw() # sample non-transformed grid points
            
            if  isinstance(H.thick_loss, str) or isinstance(H.mono_loss, str) or isinstance(H.flat_loss, str) or \
                (H.thick_loss > 0) or (H.mono_loss > 0) or (H.flat_loss > 0):

                # Numerically estimate the diagonal of the Hessian for the mono loss
                # (single batched gradient call over all shifted grids)
                d = self.input_dim
                N = gridL.shape[0]
                delta = C.delta
                offsets = torch.stack(C._offset, dim=0)  # (d, d)
                all_pos = gridL.unsqueeze(0) + offsets.unsqueeze(1)   # (d, N, d)
                all_neg = gridL.unsqueeze(0) - offsets.unsqueeze(1)   # (d, N, d)
                stacked_coords = torch.cat([all_pos, all_neg], dim=0).reshape(-1, d)  # (2*d*N, d)

                grad_all = self.gradient(
                    stacked_coords, normalize=False, transform=False,
                    accumulate=True, retain_graph=True, create_graph=True
                )  # (2*d*N, d)

                grad_pos = grad_all[: d * N].reshape(d, N, d)   # (d, N, d)
                grad_neg = grad_all[d * N :].reshape(d, N, d)   # (d, N, d)

                # Compute Hessian diagonal: ∂²φ/∂x_j² ≈ (∇φ(x+δ e_j)_j - ∇φ(x-δ e_j)_j) / (2*delta)
                hess_diag = (grad_pos - grad_neg).diagonal(dim1=0, dim2=2) / (2 * delta)  # (N, d)
                hess = torch.zeros((N, d, d), device=curlew.device, dtype=curlew.dtype)
                hess[:, range(d), range(d)] = hess_diag

                pnorm = torch.norm(grad_pos, dim=-1)   # (d, N)
                nnorm = torch.norm(grad_neg, dim=-1)   # (d, N)

                # if isinstance(H.mono_loss, str) or (H.mono_loss > 0):
                #     grad_pos_norm = grad_pos / (pnorm.unsqueeze(-1) + 1e-8)
                #     grad_neg_norm = grad_neg / (nnorm.unsqueeze(-1) + 1e-8)
                #     ndiv = (grad_pos_norm - grad_neg_norm).diagonal(dim1=0, dim2=2).sum(dim=1) / (2 * delta)

                if isinstance(H.thick_loss, str) or (H.thick_loss > 0):
                    pnorm_mean = torch.clip(pnorm.mean(dim=1, keepdim=True), 1e-8, torch.inf)
                    nnorm_mean = torch.clip(nnorm.mean(dim=1, keepdim=True), 1e-8, torch.inf)
                    L['thick_loss'] = ((1 - pnorm / pnorm_mean) ** 2).mean() + ((1 - nnorm / nnorm_mean) ** 2).mean()

                # compute derived monotonicity loss
                if isinstance(H.mono_loss, str) or (H.mono_loss > 0):
                    #L['mono_loss'] = torch.mean(ndiv**2) # (normalised) divergence should be close to 0
                    L['mono_loss'] = torch.mean(torch.abs(hess))
                #if isinstance(H.thick_loss, str) or (H.thick_loss > 0):
                #    # L['thick_loss'] = torch.mean( torch.linalg.det(hess)**2 ) # determinant should be close to 0 [ breaks in 2D, as the trace and determinant can't both be 0 unless all is 0!]
                #    L['thick_loss'] = L['thick_loss'] / (2*self.input_dim) # normalise to get average (doesn't change anything, but makes values easier to interpret)

                # Flatness Loss --  gradients everywhere parallel to trend
                if (isinstance(H.flat_loss, str) or (H.flat_loss > 0)) and (C.trend is not None):
                    if transform:
                        gv_at_grid_p = self.gradient(gridL, normalize=True, transform=self.transform, retain_graph=True, create_graph=True) # this requires gradients relative to modern coordinates! 
                    else:
                        gv_at_grid_p = self.gradient(gridL, normalize=True, transform=False, retain_graph=True, create_graph=True)
                    L['flat_loss'] = torch.mean((gv_at_grid_p - C.trend[None,:])**2) # "younging" direction
                    #flat_loss = (1 - self.closs( gv_at_grid_p, C.trend )).mean() # orientation only

        # inequality losses (single batched forward; start/end interleaved so reshape separates)
        if (C.iq is not None) and (isinstance(H.iq_loss, str) or (H.iq_loss > 0)):
            ns = C.iq[0]
            pts_list = []  # [s0_block, e0_block, s1_block, e1_block, ...]
            # compile inequality pairs to evaluate (pairs cached from the previous epoch's worst offenders, topped up with random draws)
            six_list, eix_list = [], []
            for c, (start, end, iq) in enumerate(C.iq[1]):
                if reuse_frac > 0 and self._last_iq_worst_indices is not None and c < len(self._last_iq_worst_indices):
                    six_keep, eix_keep = self._last_iq_worst_indices[c]
                    n_keep = six_keep.shape[0]
                    n_new = ns - n_keep
                    six_new = torch.randint(0, start.shape[0], (n_new,), dtype=torch.int, device=curlew.device)
                    eix_new = torch.randint(0, end.shape[0], (n_new,), dtype=torch.int, device=curlew.device)
                    six = torch.cat([six_keep, six_new])
                    eix = torch.cat([eix_keep, eix_new])
                else:
                    six = torch.randint(0, start.shape[0], (ns,), dtype=torch.int, device=curlew.device)
                    eix = torch.randint(0, end.shape[0], (ns,), dtype=torch.int, device=curlew.device)
                six_list.append(six)
                eix_list.append(eix)
                pts_list.append(start[six, :])
                pts_list.append(end[eix, :])
                
            # evaluate model and compute differences between sampled pairs
            all_pts = torch.cat(pts_list, dim=0)
            n_iq = len(pts_list) // 2
            all_vals = self(all_pts, transform=transform).flatten()
            all_vals = all_vals.view(n_iq, 2, ns)
            start_vals = all_vals[:, 0, :].reshape(-1)
            end_vals = all_vals[:, 1, :].reshape(-1)

            # Apply clamp to differences to only keep differences that violate each inequality.
            # N.B. C._iq_low_clamp / C._iq_high_clamp are pre-allocated in bind()
            delta_all = torch.clamp(start_vals - end_vals, C._iq_low_clamp, C._iq_high_clamp)**2
            L['iq_loss'] = torch.mean(delta_all)

            # update cache of inequality pairs with loss above threshold for next epoch (reuse_worst_half)
            if reuse_frac > 0 and ns > 0:
                self._last_iq_worst_indices = []
                for c in range(n_iq):
                    chunk = delta_all[c * ns : (c + 1) * ns]
                    threshold = reuse_frac * chunk.mean()
                    keep_ix = (chunk >= threshold).nonzero(as_tuple=True)[0]
                    if keep_ix.numel() == 0:
                        keep_ix = chunk.argmax().unsqueeze(0)
                    self._last_iq_worst_indices.append((
                        six_list[c][keep_ix].detach(),
                        eix_list[c][keep_ix].detach(),
                    ))

        # Dynamically adjust task weights based on the inverse of real-time loss values.
        # (dividing each term by its own detached value discards its magnitude but preserves its
        # gradient direction; a hacky but sometimes useful way to balance multi-task losses)
        if H.use_dynamic_loss_weighting:
            for k,v in L.items():
                if v > 0:
                    L[k] = v / v.detach()
        
        # parse loss hyperparameters and aggregate to get combined loss
        out = { self.name : [0,{}] }
        total_loss = 0
        for k,v in L.items():
            h = H.__getattribute__(k) # hyperparameter weight
            if isinstance(h, str):
                h = float(h) * (1/v).item() if v > 0 else 0.0
                H.__setattr__(k, h )
            
            if (h is not None) and (h > 0) and (v > 0):
                s = h*v # scaled loss term
                # store loss for debugging / reporting
                out[self.name][1][k] = (s.item(), v.item())
                total_loss = total_loss + s # aggregate loss!
            
        # store total loss too (float() also copes with the case where no term is active)
        out[self.name][0] = float(total_loss)

        # done! 
        return total_loss, out

    def fit(self, epochs, 
                 C : CSet = None, 
                 learning_rate : float = None, 
                 early_stop : tuple = (100,1e-4), 
                 transform : bool = True, 
                 best : bool = True, 
                 vb : bool = True, 
                 prefix : str = 'Training',
                 opt : list = []):
        """
        Train this neural field to fit the specified constraints.

        Parameters
        ----------
        epochs : int
            The number of epochs to train for.
        C : CSet, optional
            The set of constraints to fit this field to. If None, the previously
            bound constraint set will be used.
        learning_rate : float, optional
            Reset this NF's optimiser to the specified learning rate before training.
        early_stop : tuple, optional
            Tuple containing the early stopping criterion. This should be (n,t) such that optimisation
            stops after n epochs with <= t improvement in the loss. Set to None to disable. Early
            stopping is applied regardless of the value of `best`.
        transform : bool, optional
            True (default) if the constraints (C) are in modern coordinates that need to be transformed during fitting. If False, 
            C is considered to have already been transformed to paleo-coordinates. Note that this can be problematic if rotations
            occur (e.g. of gradient constraints!).
        best : bool, optional
            If True (default), reset the neural field weights to the state with the best loss after training.
        vb : bool, optional
            Display a tqdm progress bar to monitor training.
        prefix : str, optional
            The prefix used for the tqdm progress bar.
        opt : list, optional
            An optional list of additional optimisers to include in the training loop (zero() and step() will be called
            on these at the same time as the optimiser used for this NF's internal learnable parameters). Used to allow
            e.g., learnable fault offset.
        Returns
        -------
        loss : float
            The loss of the final (best if best=True) model state.
        details : dict
            A more detailed breakdown of the final loss. 
        """
        # set learning rate if needed
        if learning_rate is not None:
            self.set_rate(learning_rate)

        # bind the constraints
        if C is not None:
            self.bind(C)

        # Compile loss for faster repeated evaluation (PyTorch 2+)
        #if compile_loss and hasattr(torch, "compile"):
        #    _loss_fn = torch.compile(self.loss, mode="reduce-overhead")
        #else:
        _loss_fn = self.loss

        # store best state
        best_loss = np.inf
        best_loss_ = None
        best_state = None

        # for early stopping
        best_count = 0
        eps = 0
        if early_stop is not None:
            eps = early_stop[1]

        # setup progress bar
        bar = range(epochs)
        if vb:
            bar = tqdm(range(epochs), desc=prefix, bar_format="{desc}: {n_fmt}/{total_fmt}|{postfix}")
        for epoch in bar:
            loss, details = _loss_fn(transform=transform)

            if (loss.item() < (best_loss - eps)): # update best state (only on an improvement larger than eps)
                best_loss = loss.item()
                best_loss_ = details
                best_state = {k: v.detach().clone() for k, v in self.state_dict().items()}
                best_count = 0
            else: # not necessarily the best; but keep for return
                if best_state is None:
                    best_loss = loss.item()
                    best_loss_ = details
                best_count += 1

            # early stopping?
            if (early_stop is not None) and (best_count > early_stop[0]):
                break

            if vb: # update progress bar
                bar.set_postfix({ k : v[0] for k,v in details[self.name][1].items() })

            # backward pass and update
            self.zero()
            for o in opt:
                if o is not None: o.zero() # can often be None; ignore in that case.
            loss.backward(retain_graph=False)
            self.step()
            for o in opt:
                if o is not None: o.step()

        if best and (best_state is not None): # best_state stays None if no epoch was run
            self.load_state_dict(best_state)

        return best_loss, best_loss_ # return summed and detailed loss
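The Hessian-diagonal estimate used by the mono loss in `loss` above (central differences of the gradient along each axis, batched into one gradient call over all shifted points) can be reproduced outside PyTorch. This NumPy sketch assumes an analytic gradient function `grad_fn` in place of autograd and builds the offsets as δe_j; it mirrors the indexing of the listing but is not the library code.

```python
import numpy as np

def hessian_diagonal(grad_fn, grid, delta):
    """Central-difference estimate of d2phi/dx_j2 at each of the N points in an (N, d) grid."""
    N, d = grid.shape
    offsets = delta * np.eye(d)                              # (d, d): row j is delta * e_j
    all_pos = grid[None, :, :] + offsets[:, None, :]         # (d, N, d)
    all_neg = grid[None, :, :] - offsets[:, None, :]         # (d, N, d)
    stacked = np.concatenate([all_pos, all_neg], axis=0).reshape(-1, d)  # (2*d*N, d)
    g = grad_fn(stacked)                                     # one batched gradient call
    g_pos = g[: d * N].reshape(d, N, d)
    g_neg = g[d * N :].reshape(d, N, d)
    # element [j, n, j]: j-th gradient component at point n shifted along axis j
    return np.diagonal(g_pos - g_neg, axis1=0, axis2=2) / (2 * delta)  # (N, d)

# phi = x^2 + 3y^2 + z has gradient (2x, 6y, 1) and Hessian diagonal (2, 6, 0)
grad_fn = lambda p: np.stack([2 * p[:, 0], 6 * p[:, 1], np.ones(len(p))], axis=1)
grid = np.random.default_rng(0).uniform(-1.0, 1.0, size=(4, 3))
H_diag = hessian_diagonal(grad_fn, grid, delta=0.01)   # every row is approximately [2, 6, 0]
```

Batching all 2·d shifted copies into a single call trades memory for far fewer passes through the network, which is why the listing stacks the coordinates before calling `gradient`.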

A generic base for neural field implementations that learn to translate input coordinates to implicit value (or values). See the other child classes in this module (e.g., fourier, geoinr, etc.) for specific implementations.

Parameters

name : str
A (ideally unique) name for this neural field. Should typically match the name of the GeoField instance that uses this field.
H : HSet
Hyperparameters used to tune the loss function for this NF.
C : CSet, optional
Constraint set used when learning this implicit field. Default is None (can be set using field.bind(…)).
input_dim : int, optional
The dimensionality of the input space (e.g., 3 for (x, y, z)). If None (default), then default_dim will be used.
output_dim : int, optional
Dimensionality of the output (usually 1 for a scalar potential).
transform : callable
A function that transforms input coordinates prior to predictions. Must take exactly one argument as input (a tensor of positions) and return the transformed positions.
seed : int, optional
The random seed to use for any random operations.
vloss : callable, optional
The loss function to use for value fitting. Default is mean squared error (nn.MSELoss()).
scale : float, optional

A scaling factor to apply to outputs of the neural field, as often these struggle to learn functions with a large (>1) amplitude. Default is 1e2.

This value should be approximately equal to the expected range (max - min) of the scalar field that is being learned. It can be especially important when using a drift (trend), as it determines the extent to which the model initialisation is determined by the drift. Larger values should allow the model to deviate farther from the trend. Note also that this term tends to control the magnitude of residuals (to value or (in)equality constraints), so it will also interact with the learning rate.

N.B. The actual implementation of this scale depends on the neural field method being used.
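One plausible way such a scale factor can be applied, assuming the network produces outputs of order one that are multiplied by `scale` before the drift is added, is sketched below. Because the actual use of `scale` depends on the neural field method, treat `scaled_output` and `grad_wrt_raw` as hypothetical; the sketch only shows why the same raw-output error yields larger residuals, and larger gradients, as `scale` grows.

```python
import numpy as np

def scaled_output(raw, scale=100.0, drift=0.0):
    """Hypothetical: map an O(1) network output onto the scalar field's value range."""
    return scale * np.asarray(raw, dtype=float) + drift

def grad_wrt_raw(raw, target, scale, drift=0.0):
    """Derivative of the squared residual (scale * raw + drift - target)**2 w.r.t. raw."""
    residual = scaled_output(raw, scale, drift) - target   # residual grows with scale
    return 2.0 * residual * scale                          # chain rule adds another factor

# the same raw-output error (0.1) at two scales: gradients of ~20 and ~2000
for s in (10.0, 100.0):
    target = scaled_output(0.4, scale=s)   # value the network should reproduce
    print(s, grad_wrt_raw(0.5, target, scale=s))
```

The gradient grows roughly with scale², which is one reason the scale and the learning rate need to be tuned together.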

Keywords

All keywords are passed to the initField(…) function of the child class, to build the relevant neural architecture.

Methods

def fit(self,
epochs,
C: CSet = None,
learning_rate: float = None,
early_stop: tuple = (100, 0.0001),
transform: bool = True,
best: bool = True,
vb: bool = True,
prefix: str = 'Training',
opt: list = [])

Train this neural field to fit the specified constraints.

Parameters

epochs : int
The number of epochs to train for.
C : CSet, optional
The set of constraints to fit this field to. If None, the previously bound constraint set will be used.
learning_rate : float, optional
Reset this NF's optimiser to the specified learning rate before training.
early_stop : tuple, optional
Tuple containing the early stopping criterion. This should be (n,t) such that optimisation stops after n epochs with <= t improvement in the loss. Set to None to disable. Early stopping is applied regardless of the value of best.
transform : bool, optional
True (default) if the constraints (C) are in modern coordinates that need to be transformed during fitting. If False, C is considered to have already been transformed to paleo-coordinates. Note that this can be problematic if rotations occur (e.g. of gradient constraints!).
best : bool, optional
If True (default), reset the neural field weights to the state with the best loss after training.
vb : bool, optional
Display a tqdm progress bar to monitor training.
prefix : str, optional
The prefix used for the tqdm progress bar.
opt : list, optional
An optional list of additional optimisers to include in the training loop (zero() and step() will be called on these at the same time as the optimiser used for this NF's internal learnable parameters). Used to allow e.g., learnable fault offset.

Returns

loss : float
The loss of the final (best if best=True) model state.
details : dict
A more detailed breakdown of the final loss.
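The documented (n, t) early-stopping criterion can be sketched on a plain list of per-epoch losses: an epoch only counts as an improvement if its loss beats the best so far by more than t, and training stops once more than n consecutive epochs fail to improve. `early_stop_epoch` is a hypothetical helper for illustration; it does not touch model weights.

```python
def early_stop_epoch(losses, n=100, t=1e-4):
    """Return (stop_epoch, best_epoch) for a sequence of per-epoch losses."""
    best_loss, best_epoch, count = float("inf"), None, 0
    for epoch, loss in enumerate(losses):
        if loss < best_loss - t:        # significant improvement: remember this state
            best_loss, best_epoch, count = loss, epoch, 0
        else:
            count += 1                  # one more epoch without a real improvement
        if count > n:                   # patience exhausted
            return epoch, best_epoch
    return len(losses) - 1, best_epoch  # ran out of epochs first

# the loss plateaus after epoch 3, so training stops n + 1 epochs later
losses = [1.0, 0.5, 0.25, 0.2] + [0.2] * 10
print(early_stop_epoch(losses, n=3, t=1e-3))   # (7, 3)
```

With `best=True`, the weights saved at `best_epoch` would be the ones restored after the loop ends.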

def loss(self, transform=True) -> tuple
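def loss(self, transform=True) -> tuple:
    """
    Compute the loss associated with this neural field given its current state. The `transform` argument
    specifies if constraints need to be transformed from modern to paleo-coordinates before computing loss.
    Returns a tuple containing the total loss (a tensor that can be backpropagated) and a dictionary
    that breaks the loss down by term, keyed by the name of this field.
    """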
    if self.C is None:
        raise ValueError("Scalar field has no constraints; bind a CSet first (e.g., using `bind(...)`).")

    # move these into local scope for clarity
    C = self.C 
    H = self.H

    # initialise the different loss parts
    L = {}
    for k in ['value_loss', 'grad_loss', 'ori_loss', 'thick_loss', 'mono_loss', 'flat_loss', 'iq_loss']:
        L[k] = 0

    # LOCAL LOSS FUNCTIONS
    # -----------------------------
    # Value Loss
    if (C.vp is not None) and (C.vv is not None) and (isinstance(H.value_loss, str) or (H.value_loss > 0)):
        v_pred = self(C.vp, transform=transform)
        L['value_loss'] = self.vloss( v_pred, C.vv[:,None] )

    # Gradient loss
    self.reset_mnorm() # reset accumulation
    # [ N.B. positions (and thus gradients) are in un-transformed coordinates ]
    if (C.gp is not None) and (isinstance(H.grad_loss, str) or (H.grad_loss > 0)):
        gv_pred = self.gradient(C.gp, normalize=True, transform=transform, accumulate=True, retain_graph=True, create_graph=True) # compute gradient direction 
        L['grad_loss'] = self.vloss(gv_pred, C.gv) # N.B. constrains both orientation and younging direction

    # Orientation loss
    # [ N.B. positions (and thus gradients) are in un-transformed coordinates ]
    if (C.gop is not None) and (isinstance(H.ori_loss, str) or (H.ori_loss > 0)):
        gv_pred = self.gradient(C.gop, normalize=True, transform=transform, accumulate=True, retain_graph=True, create_graph=True) # compute gradient direction 
        L['ori_loss'] = torch.clamp( torch.mean( 1 - torch.abs( self.closs(gv_pred, C.gov ) ) ), min=1e-6 ) # N.B.: Orientation loss on its own fits a bit too well, numerical precision crashes avoided with the clamp - AVK

    # GLOBAL LOSS FUNCTIONS
    # -------------------------------
    reuse_frac = getattr(H, 'reuse_worst_half', 0) or 0 # set as 0 or None to disable
    if C.grid is not None:
        if transform:
            gridL = C.grid.draw(self.transform)  # sample transformed grid points
        else:
            gridL = C.grid.draw() # sample non-transformed grid points
        
        if  isinstance(H.thick_loss, str) or isinstance(H.mono_loss, str) or isinstance(H.flat_loss, str) or \
            (H.thick_loss > 0) or (H.mono_loss > 0) or (H.flat_loss > 0):

            # Numerically compute the Hessian for mono loss
            # (single batched gradient call over all shifted grids)
            d = self.input_dim
            N = gridL.shape[0]
            delta = C.delta
            offsets = torch.stack(C._offset, dim=0)  # (d, d)
            all_pos = gridL.unsqueeze(0) + offsets.unsqueeze(1)   # (d, N, d)
            all_neg = gridL.unsqueeze(0) - offsets.unsqueeze(1)   # (d, N, d)
            stacked_coords = torch.cat([all_pos, all_neg], dim=0).reshape(-1, d)  # (2*d*N, d)

            grad_all = self.gradient(
                stacked_coords, normalize=False, transform=False,
                accumulate=True, retain_graph=True, create_graph=True
            )  # (2*d*N, d)

            grad_pos = grad_all[: d * N].reshape(d, N, d)   # (d, N, d)
            grad_neg = grad_all[d * N :].reshape(d, N, d)   # (d, N, d)

            # Compute Hessian diagonal: ∂²φ/∂x_j² ≈ (∇φ(x+δ e_j)_j - ∇φ(x-δ e_j)_j) / (2*delta)
            hess_diag = (grad_pos - grad_neg).diagonal(dim1=0, dim2=2) / (2 * delta)  # (N, d)
            hess = torch.zeros((N, d, d), device=curlew.device, dtype=curlew.dtype)
            hess[:, range(d), range(d)] = hess_diag

            pnorm = torch.norm(grad_pos, dim=-1)   # (d, N)
            nnorm = torch.norm(grad_neg, dim=-1)   # (d, N)

            # if isinstance(H.mono_loss, str) or (H.mono_loss > 0):
            #     grad_pos_norm = grad_pos / (pnorm.unsqueeze(-1) + 1e-8)
            #     grad_neg_norm = grad_neg / (nnorm.unsqueeze(-1) + 1e-8)
            #     ndiv = (grad_pos_norm - grad_neg_norm).diagonal(dim1=0, dim2=2).sum(dim=1) / (2 * delta)

            if isinstance(H.thick_loss, str) or (H.thick_loss > 0):
                pnorm_mean = torch.clip(pnorm.mean(dim=1, keepdim=True), 1e-8, torch.inf)
                nnorm_mean = torch.clip(nnorm.mean(dim=1, keepdim=True), 1e-8, torch.inf)
                L['thick_loss'] = ((1 - pnorm / pnorm_mean) ** 2).mean() + ((1 - nnorm / nnorm_mean) ** 2).mean()

            # compute derived thickness and monotonicity loss
            if isinstance(H.mono_loss, str) or (H.mono_loss > 0):
                #L['mono_loss'] = torch.mean(ndiv**2) # (normalised) divergence should be close to 0
                L['mono_loss'] = torch.mean(torch.abs(hess))
            #if isinstance(H.thick_loss, str) or (H.thick_loss > 0):
            #    # L['thick_loss'] = torch.mean( torch.linalg.det(hess)**2 ) # determinant should be close to 0 [ breaks in 2D, as the trace and determinant can't both be 0 unless all is 0!]
            #    L['thick_loss'] = L['thick_loss'] / (2*self.input_dim) # normalise to get average (doesn't change anything, but makes values easier to interpret)

            # Flatness Loss --  gradients everywhere parallel to trend
            if (isinstance(H.flat_loss, str) or (H.flat_loss > 0)) and (C.trend is not None):
                if transform:
                    gv_at_grid_p = self.gradient(gridL, normalize=True, transform=self.transform, retain_graph=True, create_graph=True) # this requires gradients relative to modern coordinates! 
                else:
                    gv_at_grid_p = self.gradient(gridL, normalize=True, transform=False, retain_graph=True, create_graph=True)
                L['flat_loss'] = torch.mean((gv_at_grid_p - C.trend[None,:])**2) # "younging" direction
                #flat_loss = (1 - self.closs( gv_at_grid_p, C.trend )).mean() # orientation only

    # inequality losses (single batched forward; start/end interleaved so reshape separates)
    if (C.iq is not None) and (isinstance(H.iq_loss, str) or (H.iq_loss > 0)):
        ns = C.iq[0]
        pts_list = []  # [s0_block, e0_block, s1_block, e1_block, ...]
        # compile inequality pairs to compute (half from cached "worst half" and half randomly drawn)
        six_list, eix_list = [], []
        for c, (start, end, iq) in enumerate(C.iq[1]):
            if reuse_frac > 0 and self._last_iq_worst_indices is not None and c < len(self._last_iq_worst_indices):
                six_keep, eix_keep = self._last_iq_worst_indices[c]
                n_keep = six_keep.shape[0]
                n_new = ns - n_keep
                six_new = torch.randint(0, start.shape[0], (n_new,), dtype=torch.int, device=curlew.device)
                eix_new = torch.randint(0, end.shape[0], (n_new,), dtype=torch.int, device=curlew.device)
                six = torch.cat([six_keep, six_new])
                eix = torch.cat([eix_keep, eix_new])
            else:
                six = torch.randint(0, start.shape[0], (ns,), dtype=torch.int, device=curlew.device)
                eix = torch.randint(0, end.shape[0], (ns,), dtype=torch.int, device=curlew.device)
            six_list.append(six)
            eix_list.append(eix)
            pts_list.append(start[six, :])
            pts_list.append(end[eix, :])
            
        # evaluate model and compute differences between sampled pairs
        all_pts = torch.cat(pts_list, dim=0)
        n_iq = len(pts_list) // 2
        all_vals = self(all_pts, transform=transform).flatten()
        all_vals = all_vals.view(n_iq, 2, ns)
        start_vals = all_vals[:, 0, :].reshape(-1)
        end_vals = all_vals[:, 1, :].reshape(-1)

        # Apply clamp to differences to only keep differences that violate each inequality.
        # N.B. C._iq_low_clamp / C._iq_high_clamp are pre-allocated in bind()
        delta_all = torch.clamp(start_vals - end_vals, C._iq_low_clamp, C._iq_high_clamp)**2
        L['iq_loss'] = torch.mean(delta_all)

        # update cache of inequality pairs with loss above threshold for next epoch (reuse_worst_half)
        if reuse_frac > 0 and ns > 0:
            self._last_iq_worst_indices = []
            for c in range(n_iq):
                chunk = delta_all[c * ns : (c + 1) * ns]
                threshold = reuse_frac * chunk.mean()
                keep_ix = (chunk >= threshold).nonzero(as_tuple=True)[0]
                if keep_ix.numel() == 0:
                    keep_ix = chunk.argmax().unsqueeze(0)
                self._last_iq_worst_indices.append((
                    six_list[c][keep_ix].detach(),
                    eix_list[c][keep_ix].detach(),
                ))

    # Dynamically adjust task weights based on the inverse of real-time loss values.
    # (this ignores the magnitude of each loss term, but preserves its gradient direction,
    # and is a hacky but sometimes useful way to balance multi-task losses)
    if H.use_dynamic_loss_weighting:
        for k,v in L.items():
            if v > 0:
                L[k] = 1 / v.item()
    
    # parse loss hyperparameters and aggregate to get combined loss
    out = { self.name : [0,{}] }
    total_loss = 0
    for k,v in L.items():
        h = H.__getattribute__(k) # hyperparameter weight
        if isinstance(h, str):
            h = float(h) * (1/v).item() if v > 0 else 0.0
            H.__setattr__(k, h )
        
        if (h is not None) and (h > 0) and (v > 0):
            s = h*v # scaled loss term
            # store loss for debugging / reporting
            out[self.name][1][k] = (s.item(), v.item())
            
            # throw away loss magnitude (just use sign). Can be useful for multi-objective optimisation
            if (H.use_dynamic_loss_weighting):
                s = 1 / s.item()

            total_loss = total_loss + s # aggregate loss!
        
    # store total loss too
    out[self.name][0] = total_loss.item()

    # done! 
    return total_loss, out

Compute the loss associated with this neural field given its current state. The transform argument specifies if constraints need to be transformed from modern to paleo-coordinates before computing loss.
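The `mono_loss` term above approximates the diagonal of the field's Hessian by central differences of its gradient, using a single batched gradient call over the grid shifted by ±δ along each axis. The sketch below reproduces only that step, assuming an analytic gradient function is available; `hessian_diagonal` and `grad_phi` are hypothetical names, and NumPy stands in for torch (no autograd).

```python
import numpy as np

def hessian_diagonal(grad_fn, x, delta):
    """Central-difference estimate of d2(phi)/dx_j^2 at every row of x.

    grad_fn : callable mapping an (M, d) array of points to the (M, d) gradient of phi
    x       : (N, d) array of sample points (e.g. drawn from the constraint grid)
    delta   : finite-difference step (the role played by C.delta)
    """
    n, d = x.shape
    offsets = delta * np.eye(d)                  # row j is delta * e_j (cf. C._offset)
    pos = x[None, :, :] + offsets[:, None, :]    # (d, N, d): points shifted by +delta e_j
    neg = x[None, :, :] - offsets[:, None, :]    # (d, N, d): points shifted by -delta e_j
    g = grad_fn(np.concatenate([pos, neg], axis=0).reshape(-1, d))  # one batched call
    g_pos = g[: d * n].reshape(d, n, d)
    g_neg = g[d * n:].reshape(d, n, d)
    # entry [i, j] keeps gradient component j at point i for the shift along axis j -> (N, d)
    return np.diagonal(g_pos - g_neg, axis1=0, axis2=2) / (2 * delta)

# Example: phi(x, y) = x**2 + 3*y**2 has gradient (2x, 6y) and Hessian diagonal (2, 6).
def grad_phi(p):
    return np.stack([2 * p[:, 0], 6 * p[:, 1]], axis=1)

pts = np.random.default_rng(0).uniform(-1.0, 1.0, size=(5, 2))
hd = hessian_diagonal(grad_phi, pts, delta=0.1)

# As in loss(), place the diagonal in an (N, d, d) array and penalise its mean absolute value.
n, d = hd.shape
idx = np.arange(d)
hess = np.zeros((n, d, d))
hess[:, idx, idx] = hd
mono_loss = np.mean(np.abs(hess))
```

Because the gradient of this `phi` is linear, the central difference is exact up to rounding. For a neural field the estimate depends on `delta`, which `bind()` sets from the spacing of the first two grid points when `C.delta` is not given.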

Inherited members

class BaseSF (name: str = None,
input_dim: int = None,
output_dim: int = 1,
C: CSet = None,
H: HSet = None,
drift=0,
transform=None,
local=None,
seed: int = 42,
**kwargs)
class BaseSF(LearnableBase):
    """
    Base class for all implicit (scalar) fields, including interpolated, learned or analytical fields.
    """
    
    level = np.inf
    """
    The level of reconstruction to apply when evaluating this field. If "0" then only the drift/trend is returned. 
    If "inf" then the highest level of detail is returned. 
    
    Values between 0 and np.inf will be treated differently by different field types. Default is np.inf.
    """

    def __init__(self, name : str = None, 
                       input_dim: int = None,
                       output_dim: int = 1,
                       C: CSet = None,
                       H: HSet = None,
                       drift = 0,
                       transform = None,
                       local = None,
                       seed : int = 42, **kwargs ):
        """
        Initialise a new scalar field.

        Parameters
        ----------
        name : str
            A (ideally unique) name for this neural field. Should typically match the name of the GeoField instance that uses this field. Defaults
            to the name of this class.
        input_dim : int, optional
            The dimensionality of the input space (e.g., 3 for [x, y, z], 2 for [x,y]). If None (default) then `curlew.default_dim` is used.
        output_dim : int, optional
            Dimensionality of the output (usually 1 for a scalar potential).
        C : CSet
            Constraint set used for learned or interpolated fields. Default is None.
        H : HSet
            Hyperparameters used to tune the loss function for this NF. Default is None.
        drift : int | float | BaseSF
            A constant integer or float (to use a constant value as the drift), or another BaseSF instance (e.g., an AnalyticalField) that
            defines the trend/drift of this field. This trend/drift will be evaluated at each input coordinate and added to the output of the 
            field during the forward call, meaning learnable fields (interpolators) learn a residual relative to this drift. Default is 0 (no drift).
        transform : callable
            A function that transforms input coordinates prior to evaluation. Must take exactly one argument as input (a tensor of positions) and return the transformed positions. 
        local : `curlew.core.Transform`, optional
            A Transform object defining the transform from (possibly undeformed) field coordinates to the local coordinates passed into the neural or analytical field
            representing this scalar field. This will be applied during the forward call, and can be used to e.g., implement global anisotropy, or tweak the 
            represented structures by adding a constant offset or rotation. Defaults to an identity matrix (no transform).
        seed : int, optional
            A random seed that child classes can (optionally) use for any random operations. Default is 42.

        Keywords
        ---------
        All keywords are passed to the initField(...) function of the child class, to build the relevant
        neural architecture.

        """
        super().__init__()
        self.name = name
        if self.name is None:
            self.name = str(type(self).__name__) # default name is the name of the field type
        if input_dim is None:
            self.input_dim = curlew.default_dim
        else:
            self.input_dim = input_dim
        self.output_dim = output_dim
        self.transform = transform
        self.seed = seed # seed to use for any random operations
        self.C = None # will contain constraints if bound
        self.H = None
        if H is not None:
            self.H = H.copy() # will contain hyperparameters if bound
        if C is not None:
            self.bind(C)
        
        self.drift = drift # can be a constant or another field; evaluated during forward pass
        self.mnorm = 0 # cache the average field gradient (can be useful for quick/rough normalisation)
        self.nnorm = 0 # number of evaluations used to compute average gradient

        if local is None:
            self.T = Transform(self.input_dim)
        else:
            self.T = local

        self.initField( **kwargs ) # call child class init to build the network
    
    def initField(self, **kwargs):
        """
        Build the internal structure of this implicit field. This should be implemented
        by child classes.
        """
        assert False, "BaseSF does not implement initField()"

    def evaluate(x):
        """
        Evaluate this field at the specified input coordinates. This should be
        implemented by child classes.

        Parameters
        ----------
        x : torch.Tensor
            A tensor of shape (N, input_dim), where N is the batch size.

        Returns
        -------
        torch.Tensor
            A tensor of shape (N, output_dim) containing the field values (or predicted values).
        """
        assert False, "BaseSF does not implement evaluate()"

    def forward(self, x: torch.Tensor, transform=True) -> torch.Tensor:
        """
        Forward operator to derive field predictions from a coordinate tensor. This internally calls whatever
        `evaluate` function the child class has implemented, after first applying relevant transforms to `x`.
        
        Note that this should generally not be called directly (see `self.predict(...)` instead).

        Parameters
        ----------
        x : torch.Tensor
            A tensor of shape (N, input_dim), where N is the batch size.
        transform : bool
            If True (default), any defined transform function is applied before encoding and evaluating the field for `x`.
            `x` should thus be expressed in model coordinates that will first be reconstructed into field coordinates using
            the defined transform function. Note that this should not be confused with the `curlew.core.Transform` object (`self.T`)
            used to then convert field coordinates to local coordinates.
            
        Returns
        -------
        torch.Tensor
            A tensor of shape (N, output_dim), representing the scalar potential.
        """
        # apply transform if needed
        if transform and self.transform is not None:
            x = self.transform(x, end=transform)
        
        # unwrap geode if x is a geode instance
        geode = None
        if isinstance(x, Geode):
            geode = x # store geode
            x = geode.x # extract coordinates [ these have been transformed during self.transform(x) ]

        # apply local transform to achieve e.g., global anisotropy
        x = self.T(x) 

        # evaluate drift
        out = 0
        if isinstance(self.drift, (int, float)):
            out = out + self.drift
        elif isinstance(self.drift, BaseSF):
            out = out + self.drift(x, transform=False)

        # evaluate field
        # N.B. `self.level` can be set to 0 to evaluate only the drift! Some types of field will also use self.level to control the detail of reconstruction.
        # N.B.B. Analytical fields will always be evaluated, even at level 0, as these are not interpolations.
        if (self.level > 0) or isinstance(self, BaseAF):
            out = self.evaluate(x) + out
            if len(out.shape) == 1:
                out = out[:, None] # add extra dimension if needed (for consistency)

        # put back into geode?
        if geode:
            geode.scalar = out.squeeze()
            return geode
        else:
            return out # no need
    
    def bind( self, C ):
        """
        Bind a CSet to this field ready for loss computation (neural fields) or interpolation (interpolators).
        """
        self.C = C.torch() # make a copy

        # setup deltas for numerical differentiation if not yet defined
        C = self.C # shorthand for our copy
        if C.grid is not None:
            if C.delta is None:
                # initialise differentiation step if needed
                C.delta = np.linalg.norm( C.grid.coords()[0,:] - C.grid.coords()[1,:] ) # / 2

            if C._offset is None:
                C._offset = []
                for i in range(self.input_dim):
                    o = [0]*self.input_dim
                    o[i] = C.delta
                    C._offset.append( _tensor( o, dev=curlew.device, dt=curlew.dtype) )

        # pre-allocate inequality clamp tensors
        # (layout matches loss: one block of ns per inequality)
        if C.iq is not None:
            ns = C.iq[0]
            n_iq = len(C.iq[1])
            total_ns = ns * n_iq
            C._iq_low_clamp = torch.empty(total_ns, dtype=curlew.dtype, device=curlew.device)
            C._iq_high_clamp = torch.empty(total_ns, dtype=curlew.dtype, device=curlew.device)
            offset = 0
            for _start, _end, iq in C.iq[1]:
                if '=' in iq:
                    C._iq_low_clamp[offset : offset + ns] = ninf
                    C._iq_high_clamp[offset : offset + ns] = inf
                elif '<' in iq:
                    C._iq_low_clamp[offset : offset + ns] = 0.0
                    C._iq_high_clamp[offset : offset + ns] = inf
                elif '>' in iq:
                    C._iq_low_clamp[offset : offset + ns] = ninf
                    C._iq_high_clamp[offset : offset + ns] = 0.0
                offset += ns

    def reset_mnorm(self):
        """Reset accumulation of average gradient magnitude"""
        self.mnorm = 0
        self.nnorm = 0

    def gradient(self, coords: torch.Tensor, 
                       normalize: bool = True, 
                       transform=True, 
                       return_value=False, 
                       retain_graph=False,
                       create_graph=False,
                       accumulate=True) -> torch.Tensor:
        """
        Compute the gradient of the scalar potential with respect to the input coordinates. Note that this only works
        for scalar (i.e. 1-D) fields.

        Parameters
        ----------
        coords : torch.Tensor
            A tensor of shape (N, input_dim) representing the input coordinates.
        normalize : bool, optional
            If True, the gradient is normalized to unit length per sample.
        transform : bool
            If True, any defined transform function is applied before encoding and evaluating the field for `coords`.
        return_value : bool, optional
            If True, both the gradient and the scalar value at the evaluated points are returned.
        retain_graph : bool, optional
            True if the gradient graph should be retained (to allow e.g., subsequent backpropagation). Default is False.
        create_graph : bool, optional
            True if the gradient value should have an underlying graph to allow it to influence back-prop operations. Default is False.
        accumulate : bool, optional
            If True (default), the magnitudes of the evaluated gradients contribute to the running average gradient magnitude.
            This average can be reset using `self.reset_mnorm()` and accessed through `self.mnorm`.
        
        Returns
        -------
        torch.Tensor
            A tensor of shape (N, input_dim) representing the gradient of the scalar potential at each coordinate.
        torch.Tensor, optional
            A tensor of shape (N, 1) giving the scalar value at the evaluated points, if `return_value` is True.
        """

        # we need to compute gradients
        coords.requires_grad_(True)

        # Forward pass to get the model output and autodiff graph
        potential = self.forward(coords, transform=transform).sum(dim=-1)  # sum in case output_dim > 1

        # Compute gradient
        grad_out = torch.autograd.grad(
            outputs=potential,
            inputs=coords,
            grad_outputs=torch.ones_like(potential),
            create_graph=create_graph,
            retain_graph=retain_graph
        )[0]

        # Accumulate and/or normalize gradients?
        if accumulate or normalize:
            norm = torch.norm(grad_out, dim=-1, keepdim=True) + 1e-8
            if accumulate:
                self.mnorm = (self.mnorm*self.nnorm) + torch.mean(norm, axis=0).item()*len(norm) # update average gradient
                self.nnorm += len(norm) # update counter holding number of observations
                self.mnorm = self.mnorm / self.nnorm # convert from total to average
            if normalize:
                grad_out = grad_out / norm
        
        # Return
        if return_value:
            return grad_out, potential
        else:
            return grad_out

    ## FITTING (STUBS)
    def loss(self, transform=True) -> torch.Tensor:
        """
        Optionally implemented by child classes to facilitate optimisation and learning. The default implementation returns a zero loss.
        """
        return _tensor(0, dt=curlew.dtype, dev=curlew.device).requires_grad_(True), {self.name:(0,{})}
    
    def fit(self, *args):
        """
        Optionally implemented by learnable child classes. If not, simply returns whatever is returned by "loss".

        Returns
        -------
        loss : float
            The loss of the final result.
        details : dict
            A more detailed breakdown of the final loss. 
        """
        # loss() already returns a (loss, details) tuple, so pass it straight through
        return self.loss()

Base class for all implicit (scalar) fields, including interpolated, learned or analytical fields.

Initialise a new scalar field.

Parameters

name : str
A (ideally unique) name for this neural field. Should typically match the name of the GeoField instance that uses this field. Defaults to the name of this class.
input_dim : int, optional
The dimensionality of the input space (e.g., 3 for [x, y, z], 2 for [x, y]). If None (default) then curlew.default_dim is used.
output_dim : int, optional
Dimensionality of the output (usually 1 for a scalar potential).
C : CSet
Constraint set used for learned or interpolated fields. Default is None.
H : HSet
Hyperparameters used to tune the loss function for this NF. Default is None.
drift : int | float | BaseSF
A constant integer or float (to use a constant value as the drift), or another BaseSF instance (e.g., an AnalyticalField) that defines the trend/drift of this field. This trend/drift will be evaluated at each input coordinate and added to the output of the field during the forward call, meaning learnable fields (interpolators) learn a residual relative to this drift. Default is 0 (no drift).
transform : callable
A function that transforms input coordinates prior to evaluation. Must take exactly one argument as input (a tensor of positions) and return the transformed positions.
local : curlew.core.Transform, optional
A Transform object defining the transform from (possibly undeformed) field coordinates to the local coordinates passed into the neural or analytical field representing this scalar field. This will be applied during the forward call, and can be used to e.g., implement global anisotropy, or tweak the represented structures by adding a constant offset or rotation. Defaults to an identity matrix (no transform).
seed : int, optional
A random seed that child classes can (optionally) use for any random operations. Default is 42.

Keywords

All keywords are passed to the initField(…) function of the child class, to build the relevant neural architecture.

Ancestors

Subclasses

Class variables

var level

The level of reconstruction to apply when evaluating this field. If "0" then only the drift/trend is returned. If "inf" then the highest level of detail is returned.

Values between 0 and np.inf will be treated differently by different field types. Default is np.inf.

Methods

def bind(self, C)
Bind a CSet to this field ready for loss computation (neural fields) or interpolation (interpolators).
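When the bound `CSet` includes inequality constraints, `bind()` pre-computes one pair of clamp limits per inequality, and `loss()` squares the clamped difference `start - end` so that only violating pairs contribute. The sketch below reproduces that logic for a single inequality with NumPy; `clamp_bounds` and `iq_penalty` are hypothetical helpers, not curlew functions. Following the order of the checks in `bind()`, `'='` is tested first, so a string such as `'<='` would be treated as an equality.

```python
import numpy as np

def clamp_bounds(iq):
    """(low, high) clamp limits for one inequality string, tested in the same order as bind()."""
    if '=' in iq:
        return -np.inf, np.inf   # equality: every nonzero difference is penalised
    elif '<' in iq:
        return 0.0, np.inf       # start < end: only start > end is penalised
    elif '>' in iq:
        return -np.inf, 0.0      # start > end: only start < end is penalised
    raise ValueError(f"unrecognised inequality {iq!r}")  # sketch-only guard

def iq_penalty(start_vals, end_vals, iq):
    """Mean squared violation between field values at paired start/end points."""
    low, high = clamp_bounds(iq)
    diff = np.asarray(start_vals, dtype=float) - np.asarray(end_vals, dtype=float)
    return float(np.mean(np.clip(diff, low, high) ** 2))

start = [0.0, 1.0]
end = [0.5, 0.5]
lt = iq_penalty(start, end, '<')  # only the pair with start > end contributes
gt = iq_penalty(start, end, '>')  # only the pair with start < end contributes
eq = iq_penalty(start, end, '=')  # both pairs contribute
```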

def evaluate(x)
Evaluate this field at the specified input coordinates. This should be implemented by child classes.

Parameters

x : torch.Tensor
A tensor of shape (N, input_dim), where N is the batch size.

Returns

torch.Tensor
A tensor of shape (N, output_dim) containing the field values (or predicted values).
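Child classes implement `evaluate()` on coordinates that have already passed through `transform` and the local `Transform`, and should return an `(N, output_dim)` result. As a hypothetical illustration, the function below evaluates a planar scalar field phi(x) = (x - origin) · n_hat on NumPy arrays; in a real child class the same body would sit in `evaluate(self, x)` and operate on torch tensors. Returning `(N, 1)` directly, rather than relying on `forward()` to add the trailing axis, also avoids broadcasting a 1-D result against an `(N, 1)` field-valued drift, since `forward()` adds the drift before reshaping.

```python
import numpy as np

def planar_evaluate(x, origin, normal):
    """Hypothetical evaluate() body: signed distance to a plane, phi(x) = (x - origin) . n_hat.

    x      : (N, d) array of local coordinates
    origin : (d,) point on the plane where phi = 0
    normal : (d,) plane normal (need not be unit length)
    Returns an (N, 1) array, i.e. (N, output_dim) with output_dim = 1.
    """
    n_hat = np.asarray(normal, dtype=float)
    n_hat = n_hat / np.linalg.norm(n_hat)  # unit normal, so |grad phi| = 1 everywhere
    x = np.asarray(x, dtype=float)
    return ((x - np.asarray(origin, dtype=float)) @ n_hat)[:, None]

values = planar_evaluate([[0, 0, 0], [0, 0, 5], [1, 2, 3]], origin=[0, 0, 1], normal=[0, 0, 2])
```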
def fit(self, *args)
Optionally implemented by learnable child classes. If not, simply returns whatever is returned by "loss".

Returns

loss : float
The loss of the final result.
details : dict
A more detailed breakdown of the final loss.
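The details returned with the loss follow the layout built in `loss()`: `{name: [total, {term: (scaled, raw)}]}`, where each raw term is multiplied by its hyperparameter weight. A weight given as a string (e.g. `"0.1"`) is converted on first use to `float("0.1") / raw`, so that term initially contributes 0.1, and the numeric weight is written back into `H`. The plain-Python sketch below mirrors that bookkeeping with dictionaries in place of `HSet` and floats in place of tensors; `aggregate_losses` and the field name `'my_field'` are hypothetical, and the dynamic loss weighting branch is omitted.

```python
def aggregate_losses(name, L, H):
    """Weight and sum raw loss terms.

    L : dict mapping term name -> raw loss value (float)
    H : dict mapping term name -> weight (float, numeric string, or None); updated in place
    """
    total = 0.0
    details = {}
    for k, v in L.items():
        h = H.get(k)
        if isinstance(h, str):
            # string weights are relative: fix the weight so this term initially contributes float(h)
            h = float(h) / v if v > 0 else 0.0
            H[k] = h
        if h is not None and h > 0 and v > 0:
            s = h * v                 # scaled loss term
            details[k] = (s, v)       # (scaled, raw), as reported per term
            total += s
    return total, {name: [total, details]}

H = {'value_loss': '1.0', 'grad_loss': 2.0, 'iq_loss': 1.0}
L = {'value_loss': 4.0, 'grad_loss': 0.5, 'iq_loss': 0.0}
total, out = aggregate_losses('my_field', L, H)
```

On later calls `H['value_loss']` is already numeric, so the weight stays fixed at the value derived from the first evaluation.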
def forward(self, x: torch.Tensor, transform=True) ‑> torch.Tensor
Forward operator to derive field predictions from a coordinate tensor. This internally calls whatever evaluate function the child class has implemented, after first applying relevant transforms to x.

Note that this should generally not be called directly (see self.predict(…) instead).

Parameters

x : torch.Tensor
A tensor of shape (N, input_dim), where N is the batch size.
transform : bool
If True (default), any defined transform function is applied before encoding and evaluating the field for x. x should thus be expressed in model coordinates that will first be reconstructed into field coordinates using the defined transform function. Note that this should not be confused with the curlew.core.Transform object (self.T) used to then convert field coordinates to local coordinates.

Returns

torch.Tensor
A tensor of shape (N, output_dim), representing the scalar potential.
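The order of operations in `forward()` is: the optional `transform` (model to field coordinates), the local `Transform` `self.T`, the drift (evaluated in local coordinates with `transform=False`), and finally `evaluate()`, which is skipped when `level` is 0 unless the field is analytical. The NumPy sketch below mirrors that order; `SketchField` and its `analytical` flag are hypothetical stand-ins (curlew checks `isinstance(self, BaseAF)`), and `Geode` handling and the `end=` keyword that `forward()` passes to the transform are left out.

```python
import numpy as np

class SketchField:
    """Hypothetical stand-in for a BaseSF subclass, mirroring the steps of forward()."""
    level = np.inf  # class default, as in BaseSF

    def __init__(self, evaluate, drift=0, transform=None, local=None, analytical=False):
        self.evaluate = evaluate          # field function: (N, d) array -> (N,) or (N, 1)
        self.drift = drift                # constant, or another SketchField
        self.transform = transform        # model -> field coordinates (optional)
        self.T = local if local is not None else (lambda x: x)  # field -> local coordinates
        self.analytical = analytical      # stands in for isinstance(self, BaseAF)

    def __call__(self, x, transform=True):
        if transform and self.transform is not None:
            x = self.transform(x)
        x = self.T(x)
        out = 0
        if isinstance(self.drift, (int, float)):
            out = out + self.drift
        elif isinstance(self.drift, SketchField):
            out = out + self.drift(x, transform=False)  # drift sees local coordinates
        if self.level > 0 or self.analytical:
            out = self.evaluate(x) + out
            if out.ndim == 1:
                out = out[:, None]        # add a trailing axis for consistency
        return out

# A planar trend (phi = z) used as the drift of a small residual field.
trend = SketchField(lambda x: x[:, 2], analytical=True)
residual = SketchField(lambda x: 0.1 * np.sin(x[:, :1]), drift=trend)

pts = np.array([[0.0, 0.0, 1.0], [np.pi / 2, 0.0, 2.0]])
full = residual(pts)        # drift + residual
residual.level = 0
drift_only = residual(pts)  # level 0: only the drift is returned
```

Setting `level` on the instance shadows the class default, which is how a learnable field can be switched to "drift only" without affecting other fields.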
def gradient(self,
coords: torch.Tensor,
normalize: bool = True,
transform=True,
return_value=False,
retain_graph=False,
create_graph=False,
accumulate=True) ‑> torch.Tensor
Expand source code
def gradient(self, coords: torch.Tensor, 
                   normalize: bool = True, 
                   transform=True, 
                   return_value=False, 
                   retain_graph=False,
                   create_graph=False,
                   accumulate=True) -> torch.Tensor:
    """
    Compute the gradient of the scalar potential with respect to the input coordinates. Note that this only works
    for scalar (i.e. 1-D) fields.

    Parameters
    ----------
    coords : torch.Tensor
        A tensor of shape (N, input_dim) representing the input coordinates.
    normalize : bool, optional
        If True, the gradient is normalized to unit length per sample.
    transform : bool
        If True, any defined transform function is applied before encoding and evaluating the field for `coords`.
    return_value : bool, optional
        If True, both the gradient and the scalar value at the evaluated points are returned.
    retain_graph : bool, optional
        True if the gradient graph should be retained (to allow e.g., subsequent backpropagation). Default is False.
    create_graph : bool, optional
        True if the gradient value should have an underlying graph to allow it to influence back-prop operations. Default is False.
    accumulate : bool, optional
        If True (default), the gradient magnitudes from this evaluation contribute to the running average gradient
        magnitude. This average can be accessed through `self.mnorm` and reset using `self.reset_mnorm`.
    
    Returns
    -------
    torch.Tensor
        A tensor of shape (N, input_dim) representing the gradient of the scalar potential at each coordinate.
    torch.Tensor, optional
        A tensor of shape (N,) giving the scalar value (summed over output dimensions) at the evaluated points, if `return_value` is True.
    """

    # we need to compute gradients
    coords.requires_grad_(True)

    # Forward pass to get the model output and autodiff graph
    potential = self.forward(coords, transform=transform).sum(dim=-1)  # sum in case output_dim > 1

    # Compute gradient
    grad_out = torch.autograd.grad(
        outputs=potential,
        inputs=coords,
        grad_outputs=torch.ones_like(potential),
        create_graph=create_graph,
        retain_graph=retain_graph
    )[0]

    # Accumulate and/or normalize gradients?
    if accumulate or normalize:
        norm = torch.norm(grad_out, dim=-1, keepdim=True) + 1e-8
        if accumulate:
            self.mnorm = (self.mnorm*self.nnorm) + torch.mean(norm, axis=0).item()*len(norm) # update total gradient magnitude
            self.nnorm += len(norm) # update counter holding number of observations
            self.mnorm = self.mnorm / self.nnorm # convert from total to average
        if normalize:
            grad_out = grad_out / norm
    
    # Return
    if return_value:
        return grad_out, potential
    else:
        return grad_out

Compute the gradient of the scalar potential with respect to the input coordinates. Note that this is only meaningful for scalar (i.e. 1-D) fields; if output_dim > 1, the outputs are summed before differentiation.

Parameters

coords : torch.Tensor
A tensor of shape (N, input_dim) representing the input coordinates.
normalize : bool, optional
If True, the gradient is normalized to unit length per sample.
transform : bool, optional
If True (default), any defined transform function is applied before encoding and evaluating the field for coords.
return_value : bool, optional
If True, both the gradient and the scalar value at the evaluated points are returned.
retain_graph : bool, optional
True if the gradient graph should be retained (to allow e.g., subsequent backpropagation). Default is False.
create_graph : bool, optional
True if the gradient value should have an underlying graph to allow it to influence back-prop operations. Default is False.
accumulate : bool, optional
If True (default), the gradient magnitudes from this evaluation contribute to the running average gradient magnitude. This average can be accessed through self.mnorm and reset using self.reset_mnorm.

Returns

torch.Tensor
A tensor of shape (N, input_dim) representing the gradient of the scalar potential at each coordinate.
torch.Tensor, optional
A tensor of shape (N,) giving the scalar value (summed over output dimensions) at the evaluated points, if return_value is True.
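The autograd procedure used by gradient() can be illustrated on a planar field whose gradient is known analytically. This is a minimal sketch, assuming a plain callable `f` in place of a curlew field; `field_gradient` is a hypothetical helper that omits transforms and the running-average bookkeeping, and it differentiates a detached copy of the input rather than flagging the caller's tensor in place.

```python
import torch


def field_gradient(f, coords, normalize=True, return_value=False):
    """Hypothetical helper: autograd gradient of a scalar field f at coords."""
    # Work on a detached copy so the caller's tensor is left unchanged
    coords = coords.detach().clone().requires_grad_(True)
    potential = f(coords).sum(dim=-1)  # (N, output_dim) -> (N,)
    grad = torch.autograd.grad(
        outputs=potential,
        inputs=coords,
        grad_outputs=torch.ones_like(potential),
    )[0]
    if normalize:
        # Small epsilon avoids division by zero where the gradient vanishes
        grad = grad / (torch.norm(grad, dim=-1, keepdim=True) + 1e-8)
    return (grad, potential) if return_value else grad


def plane(x):
    # Planar field f(x, y, z) = 2x + 2y + z, returned with shape (N, 1)
    return (x @ torch.tensor([2.0, 2.0, 1.0])).unsqueeze(-1)


pts = torch.rand(4, 3)
g = field_gradient(plane, pts, normalize=False)  # (2, 2, 1) at every point
n = field_gradient(plane, pts)                   # unit normal (2/3, 2/3, 1/3)
```

Because the field is linear, the raw gradient is constant and its magnitude is 3, so normalisation simply divides by 3.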
def initField(self, **kwargs)
def initField(self, **kwargs):
    """
    Build the internal structure of this implicit field. This should be implemented
    by child classes.
    """
    raise NotImplementedError("BaseNF does not implement initField()")

Build the internal structure of this implicit field. This should be implemented by child classes.
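A child class typically overrides initField() to build its internal structure. The sketch below assumes, for illustration only, that the base constructor forwards its keyword arguments to initField(); `SketchField` and `LinearSketch` are hypothetical classes, not part of curlew.

```python
class SketchField:
    """Hypothetical base class illustrating the initField() contract."""

    def __init__(self, **kwargs):
        # Assumption: children build their internal structure during construction
        self.initField(**kwargs)

    def initField(self, **kwargs):
        # Raising (rather than `assert False`) still fails under `python -O`
        raise NotImplementedError(f"{type(self).__name__} does not implement initField()")


class LinearSketch(SketchField):
    """Hypothetical child storing the normal vector of a planar field."""

    def initField(self, normal=(0.0, 0.0, 1.0), **kwargs):
        self.normal = tuple(normal)


f = LinearSketch(normal=(1.0, 0.0, 0.0))

# Instantiating the base class directly fails with a clear message
try:
    SketchField()
    msg = None
except NotImplementedError as e:
    msg = str(e)
```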

def loss(self, transform=True) -> torch.Tensor
def loss(self, transform=True) -> torch.Tensor:
    """
    Optionally implemented by child classes to facilitate optimisation and learning. By default, a zero loss is returned.
    """
    return _tensor(0, dt=curlew.dtype, dev=curlew.device).requires_grad_(True), {self.name:(0,{})}

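Optionally implemented by child classes to facilitate optimisation and learning. The default implementation returns a zero loss tensor (with requires_grad set) together with the dictionary {self.name: (0, {})}.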
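The (loss, details) return convention of loss() can be sketched as below. This is a minimal illustration under stated assumptions: the classes are hypothetical, and we assume the tuple stored under each name holds a scalar loss value followed by a dictionary of named sub-terms. The default zero loss still requires gradients, so it can be summed with other losses and back-propagated without special-casing.

```python
import torch


class LossSketch:
    """Hypothetical field illustrating the (loss, details) return convention."""

    def __init__(self, name="sketch"):
        self.name = name

    def loss(self, transform=True):
        # Default behaviour: a zero loss that can still join a back-propagated sum
        zero = torch.tensor(0.0, requires_grad=True)
        return zero, {self.name: (0, {})}


class PenalisedSketch(LossSketch):
    """Hypothetical child class adding an L2 penalty on a learnable parameter."""

    def __init__(self, name="penalised"):
        super().__init__(name)
        self.w = torch.tensor([3.0, 4.0], requires_grad=True)

    def loss(self, transform=True):
        penalty = (self.w ** 2).sum()  # squared L2 norm of the parameters
        return penalty, {self.name: (penalty.item(), {"l2": penalty.item()})}


# Sum the losses of several fields and back-propagate once
p = PenalisedSketch()
total, details = torch.tensor(0.0), {}
for field in (LossSketch(), p):
    value, info = field.loss()
    total = total + value
    details.update(info)
total.backward()  # p.w.grad becomes 2 * w = [6., 8.]
```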

def reset_mnorm(self)
def reset_mnorm(self):
    """Reset the accumulated average gradient magnitude."""
    self.mnorm = 0
    self.nnorm = 0

Reset the accumulated average gradient magnitude.
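The mnorm/nnorm bookkeeping updated by gradient(accumulate=True) is an incremental mean: the stored average is converted back into a total, the new batch of magnitudes is added, and the result is divided by the updated count. The sketch below reproduces that update in plain Python; `RunningNorm` and its `accumulate` method are hypothetical names used only for illustration.

```python
class RunningNorm:
    """Hypothetical sketch of the mnorm/nnorm running mean."""

    def __init__(self):
        self.reset_mnorm()

    def reset_mnorm(self):
        self.mnorm = 0.0  # current average gradient magnitude
        self.nnorm = 0    # number of samples folded into the average

    def accumulate(self, norms):
        # Average -> total, add the new batch, then divide by the new count
        total = self.mnorm * self.nnorm + sum(norms)
        self.nnorm += len(norms)
        self.mnorm = total / self.nnorm


rn = RunningNorm()
rn.accumulate([1.0, 3.0])  # mean of first batch: 2.0
rn.accumulate([4.0])       # overall mean: (1 + 3 + 4) / 3
after = (rn.mnorm, rn.nnorm)
rn.reset_mnorm()           # start a fresh average
```

Resetting between training stages keeps gradient magnitudes from earlier stages out of the new average.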

Inherited members