jboat.utils

Submodules

jboat.utils.op_utils

class jboat.utils.op_utils.DynamicalSystemRules

Bases: object

A class to store and manage the operator order rules for dynamical systems.

static get_gradient_order()

Get the current gradient operator order.

Returns:

The current gradient operator order.

Return type:

List[List[str]]

static set_gradient_order(new_order)

Set a new gradient operator order.

Parameters:

new_order (List[List[str]]) – The new gradient operator order to set.

Raises:

ValueError – If the new order is invalid.

class jboat.utils.op_utils.HyperGradientRules

Bases: object

A class to store and manage gradient operator rules.

static get_gradient_order()

Get the current gradient operator order.

Returns:

The current gradient operator order.

Return type:

List[List[str]]

static set_gradient_order(new_order)

Set a new gradient operator order.

Parameters:

new_order (List[List[str]]) – The new gradient operator order to set.

Raises:

ValueError – If the new order is invalid.
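
Both rule classes expose the same static get/set interface. The following is a minimal sketch of that pattern, written as a standalone stand-in rather than the library's implementation. The class name `OperatorOrderRules`, the placeholder operator names, and the validation rule (a list of lists of strings) are assumptions made for illustration.

```python
from typing import List


class OperatorOrderRules:
    """Hypothetical stand-in for a static operator-order registry."""

    # Assumed default; the real operator names are not given in this reference.
    _order: List[List[str]] = [["op_a", "op_b"], ["op_c"]]

    @staticmethod
    def get_gradient_order() -> List[List[str]]:
        return OperatorOrderRules._order

    @staticmethod
    def set_gradient_order(new_order: List[List[str]]) -> None:
        # Assumed validity check: a list of lists whose items are all strings.
        valid = isinstance(new_order, list) and all(
            isinstance(group, list) and all(isinstance(op, str) for op in group)
            for group in new_order
        )
        if not valid:
            raise ValueError("new_order must be a list of lists of strings")
        OperatorOrderRules._order = new_order


OperatorOrderRules.set_gradient_order([["op_c"], ["op_a", "op_b"]])
current = OperatorOrderRules.get_gradient_order()
```

Because the order is held at class level, a change made through the setter is visible to every later caller of the getter.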

class jboat.utils.op_utils.ResultStore

Bases: object

A simple class to store and manage intermediate results of hyper-gradient computation.

add(name, result)

Add a result to the store.

Parameters:
  • name (str) – The name of the result (e.g., ‘gradient_operator_results_0’).

  • result (Dict) – The result dictionary to store.

clear()

Clear all stored results.

get_results()

Retrieve all stored results.
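
The add / clear / get_results interface can be sketched with a small stand-in class. This is not the library's code: the name `SimpleResultStore`, the internal list layout, and the `"upper_loss"` key in the example are assumptions for illustration only.

```python
from typing import Dict, List


class SimpleResultStore:
    """Hypothetical stand-in mirroring the add / clear / get_results interface."""

    def __init__(self) -> None:
        self._results: List[Dict] = []

    def add(self, name: str, result: Dict) -> None:
        # Assumption: each entry is kept as a one-item dict keyed by its name.
        self._results.append({name: result})

    def clear(self) -> None:
        self._results.clear()

    def get_results(self) -> List[Dict]:
        return self._results


store = SimpleResultStore()
store.add("gradient_operator_results_0", {"upper_loss": 0.25})
snapshot = list(store.get_results())  # copy taken before clearing
store.clear()
```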

jboat.utils.op_utils.average_grad(model, batch_size)

Divide the gradients of all model parameters by the batch size.

Parameters:
  • model (jittor.Module) – The model whose gradients need to be averaged.

  • batch_size (int) – The batch size to divide gradients by.
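
As a rough illustration of the arithmetic, the sketch below divides accumulated gradients by the batch size. Plain Python lists stand in for the per-parameter gradients of a jittor.Module, and the helper name `average_grads` is hypothetical.

```python
from typing import List


def average_grads(grads: List[List[float]], batch_size: int) -> List[List[float]]:
    """Divide every gradient entry by batch_size (lists stand in for parameter grads)."""
    return [[g / batch_size for g in grad] for grad in grads]


summed = [[4.0, 8.0], [2.0]]  # gradients accumulated over 4 samples
averaged = average_grads(summed, batch_size=4)
```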

jboat.utils.op_utils.cat_list_to_tensor(list_tx)

Concatenate a list of tensors into a single flattened tensor.

Parameters:

list_tx (List[jittor.Var]) – The list of tensors to concatenate.

Returns:

A single flattened tensor.

Return type:

jittor.Var
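
The flatten-then-concatenate behaviour can be sketched as follows. Nested Python lists stand in for jittor.Var tensors of arbitrary shape, and the helper names `flatten` and `cat_list` are hypothetical.

```python
from typing import List, Union

Nested = Union[float, List["Nested"]]  # stand-in for a tensor of any shape


def flatten(t: Nested) -> List[float]:
    """Flatten one nested list into a flat list, keeping row-major order."""
    if isinstance(t, list):
        out: List[float] = []
        for item in t:
            out.extend(flatten(item))
        return out
    return [t]


def cat_list(list_tx: List[Nested]) -> List[float]:
    """Flatten each tensor, then concatenate the pieces in list order."""
    flat: List[float] = []
    for t in list_tx:
        flat.extend(flatten(t))
    return flat


# A 2x2 "tensor" followed by a length-2 "tensor".
flat = cat_list([[[1.0, 2.0], [3.0, 4.0]], [5.0, 6.0]])
```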

jboat.utils.op_utils.cg_step(Ax, b, max_iter=100, epsilon=1e-05)

Run the Conjugate Gradient (CG) method to solve the linear system Ax = b.

Parameters:
  • Ax (Callable) – Function that computes the matrix-vector product Ax for a given x.

  • b (List[jittor.Var]) – Right-hand side of the equation Ax = b.

  • max_iter (int, optional) – Maximum number of iterations for the CG method, by default 100.

  • epsilon (float, optional) – Convergence threshold for the residual norm, by default 1e-5.

Returns:

Solution vector x that approximately solves Ax = b.

Return type:

List[jittor.Var]
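
The following is a minimal sketch of textbook conjugate gradient with the same calling convention: a callable for the matrix-vector product and a list-valued right-hand side. It is not the library's implementation. Lists of Python floats stand in for List[jittor.Var], the helper names are hypothetical, and the operator behind `Ax` is assumed to be symmetric positive definite.

```python
from typing import Callable, List

Vecs = List[List[float]]  # stand-in for List[jittor.Var]


def dot(u: Vecs, v: Vecs) -> float:
    """Inner product taken across every entry of every vector in the list."""
    return sum(a * b for ui, vi in zip(u, v) for a, b in zip(ui, vi))


def axpy(alpha: float, x: Vecs, y: Vecs) -> Vecs:
    """Return alpha * x + y, entry by entry."""
    return [[alpha * a + b for a, b in zip(xi, yi)] for xi, yi in zip(x, y)]


def cg_solve(Ax: Callable[[Vecs], Vecs], b: Vecs,
             max_iter: int = 100, epsilon: float = 1e-5) -> Vecs:
    x = [[0.0] * len(bi) for bi in b]      # start from x = 0
    r = [list(bi) for bi in b]             # residual r = b - A x = b
    p = [list(ri) for ri in r]             # first search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < epsilon:        # stop once the residual norm is small
            break
        Ap = Ax(p)
        alpha = rs_old / dot(p, Ap)
        x = axpy(alpha, p, x)
        r = axpy(-alpha, Ap, r)
        rs_new = dot(r, r)
        p = axpy(rs_new / rs_old, p, r)    # p = r + beta * p
        rs_old = rs_new
    return x


# A = [[4, 1], [1, 3]] is symmetric positive definite; b = [1, 2].
def Ax(v: Vecs) -> Vecs:
    v0, v1 = v[0]
    return [[4.0 * v0 + 1.0 * v1, 1.0 * v0 + 3.0 * v1]]


x = cg_solve(Ax, [[1.0, 2.0]])  # exact solution is [1/11, 7/11]
```

Passing the product as a callable means the matrix never has to be formed explicitly, which is what makes the method usable when A is a Hessian or Jacobian available only through automatic differentiation.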

jboat.utils.op_utils.conjugate_gradient(params, hparams, upper_loss, lower_loss, K, fp_map, tol=1e-10, stochastic=False)

Compute hyperparameter gradients using the Conjugate Gradient method.

Parameters:
  • params (List[jittor.Var]) – List of parameters for the lower-level optimization problem.

  • hparams (List[jittor.Var]) – List of hyperparameters for the upper-level optimization problem.

  • upper_loss (jittor.Var) – Loss value of the upper-level problem.

  • lower_loss (jittor.Var) – Loss value of the lower-level problem.

  • K (int) – Maximum number of iterations for the Conjugate Gradient method.

  • fp_map (Callable) – Fixed-point map function that computes updates to lower-level parameters.

  • tol (float, optional) – Tolerance for early stopping based on convergence, by default 1e-10.

  • stochastic (bool, optional) – If True, recompute the fixed-point map during each iteration, by default False.

Returns:

Hyperparameter gradients computed using the Conjugate Gradient method.

Return type:

List[jittor.Var]
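
For orientation, the quantity that fixed-point hypergradient methods of this kind usually compute can be written as below. This is the standard implicit-differentiation formulation, assumed here rather than confirmed from the library source. The symbols w (params), \lambda (hparams), \Phi (fp_map), and f (upper_loss) are notation introduced for this note only.

```latex
% Assumed setting: w^* = \Phi(w^*, \lambda) at the lower-level solution.
\frac{d f}{d \lambda}
  = \nabla_{\lambda} f(w^*, \lambda)
  + \partial_{\lambda} \Phi(w^*, \lambda)^{\top} v,
\qquad
\bigl(I - \partial_{w} \Phi(w^*, \lambda)^{\top}\bigr)\, v = \nabla_{w} f(w^*, \lambda).
```

Under this reading, CG is the solver applied to the linear system in v, with K bounding its iterations and tol acting as its stopping tolerance.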

jboat.utils.op_utils.copy_parameter_from_list(y, z)

Copy parameters from a list to the parameters of a Jittor model.

Parameters:
  • y (jittor.Module) – Jittor model with parameters to be updated.

  • z (List[jittor.Var]) – List of variables to copy from.

Returns:

Updated model.

Return type:

jittor.Module

jboat.utils.op_utils.custom_grad(outputs, inputs, grad_outputs=None, retain_graph=False)

Compute the vector-Jacobian product for Jittor, mimicking PyTorch’s autograd.grad.

Parameters:
  • outputs (Sequence[jittor.Var]) – Outputs of the differentiated function.

  • inputs (Sequence[jittor.Var]) – Inputs with respect to which the gradient will be computed.

  • grad_outputs (Sequence[jittor.Var], optional) – Gradients with respect to the outputs, by default None.

  • retain_graph (bool, optional) – Whether to retain the computation graph after computing the gradients, by default False.

Returns:

Gradients with respect to the inputs.

Return type:

List[jittor.Var]
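
To make the role of grad_outputs concrete, the sketch below computes a vector-Jacobian product from an explicitly written Jacobian. An autodiff backend obtains the same result without ever forming J. The helper name `vjp` and the example function are illustrative assumptions and are not part of the library.

```python
from typing import List


def vjp(jacobian: List[List[float]], grad_outputs: List[float]) -> List[float]:
    """Return v^T J, where J[i][j] = d outputs[i] / d inputs[j]."""
    n_inputs = len(jacobian[0])
    return [
        sum(v * row[j] for v, row in zip(grad_outputs, jacobian))
        for j in range(n_inputs)
    ]


# outputs = (x0 * x1, x0 + x1), evaluated at inputs x = (2, 3)
x0, x1 = 2.0, 3.0
J = [[x1, x0],    # gradient of x0 * x1
     [1.0, 1.0]]  # gradient of x0 + x1
grads = vjp(J, grad_outputs=[1.0, 10.0])  # weights the two outputs by 1 and 10
```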

jboat.utils.op_utils.get_outer_gradients(outer_loss, params, hparams, retain_graph=True)

Compute gradients of the outer-level loss with respect to parameters and hyperparameters.

Parameters:
  • outer_loss (jittor.Var) – The scalar loss from the outer-level optimization problem.

  • params (List[jittor.Var]) – The list of parameters for which gradients with respect to the outer loss are computed.

  • hparams (List[jittor.Var]) – The list of hyperparameters for which gradients with respect to the outer loss are computed.

  • retain_graph (bool, optional) – Whether to retain the computation graph after computing the gradients, by default True.

Returns:

Gradients with respect to parameters and hyperparameters.

Return type:

Tuple[List[jittor.Var], List[jittor.Var]]

jboat.utils.op_utils.grad_unused_zero(output, inputs, retain_graph=False)

Compute gradients for inputs with respect to the output, filling missing gradients with zeros.

Parameters:
  • output (jittor.Var) – The output tensor to compute gradients for.

  • inputs (List[jittor.Var]) – The input tensors to compute gradients with respect to.

  • retain_graph (bool, optional) – Whether to retain the computation graph, by default False.

Returns:

Gradients with respect to the inputs, with zeros for unused gradients.

Return type:

Tuple[jittor.Var]
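
The zero-filling step can be sketched as follows. A missing gradient is modelled here as None, which is how PyTorch reports unused inputs; whether Jittor signals them the same way is an assumption. Lists of floats stand in for tensors and the helper name is hypothetical.

```python
from typing import List, Optional, Tuple


def fill_unused_with_zeros(
    grads: List[Optional[List[float]]], inputs: List[List[float]]
) -> Tuple[List[float], ...]:
    """Replace each missing gradient (None) by zeros shaped like its input."""
    return tuple(
        g if g is not None else [0.0] * len(x)
        for g, x in zip(grads, inputs)
    )


inputs = [[1.0, 2.0], [3.0, 4.0, 5.0]]
raw_grads = [[0.5, -0.5], None]  # the second input did not affect the output
filled = fill_unused_with_zeros(raw_grads, inputs)
```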

jboat.utils.op_utils.l2_reg(parameters)

Compute the L2 regularization term (Jittor version).

Parameters:

parameters (List[jittor.Var]) – The list of parameters over which the L2 regularization term is computed.

Returns:

The L2 regularization loss.

Return type:

jittor.Var
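
A minimal sketch of one common definition of this term, the sum of squared entries over all parameters. The reference does not state the exact formula, so the absence of a scaling factor (such as 1/2) is an assumption, and the helper name `l2_penalty` is hypothetical. Lists of floats stand in for tensors.

```python
from typing import List


def l2_penalty(parameters: List[List[float]]) -> float:
    """Sum of squared entries across every parameter in the list."""
    return sum(w * w for p in parameters for w in p)


penalty = l2_penalty([[3.0, 4.0], [1.0]])  # 9 + 16 + 1
```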

jboat.utils.op_utils.list_tensor_matmul(list1, list2)

Compute the inner product of two lists of tensors: multiply corresponding tensors element-wise and sum all the products.

Parameters:
  • list1 (List[jittor.Var]) – The first list of tensors.

  • list2 (List[jittor.Var]) – The second list of tensors.

Returns:

The resulting scalar from element-wise multiplication and summation.

Return type:

jittor.Var
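
The reduction amounts to one inner product taken across the whole list, as in the sketch below. Lists of floats stand in for tensors and the helper name `list_inner_product` is hypothetical.

```python
from typing import List


def list_inner_product(list1: List[List[float]], list2: List[List[float]]) -> float:
    """Multiply matching entries of matching tensors and sum everything."""
    return sum(a * b for t1, t2 in zip(list1, list2) for a, b in zip(t1, t2))


value = list_inner_product([[1.0, 2.0], [3.0]], [[4.0, 5.0], [6.0]])  # 4 + 10 + 18
```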

jboat.utils.op_utils.list_tensor_norm(list_tensor, p=2)

Compute the p-norm of a list of tensors.

Parameters:
  • list_tensor (List[jittor.Var]) – The list of tensors to compute the norm for.

  • p (float, optional) – The order of the norm, by default 2 (Euclidean norm).

Returns:

The computed p-norm of the list of tensors.

Return type:

jittor.Var

Raises:

ValueError – If the list of tensors is empty.
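
A sketch of the p-norm taken over every entry in the list, including the empty-list check. It assumes the list is treated as one concatenated vector, which the reference does not state explicitly. Lists of floats stand in for tensors and the helper name `list_norm` is hypothetical.

```python
from typing import List


def list_norm(list_tensor: List[List[float]], p: float = 2.0) -> float:
    """p-norm over all entries of all tensors in the list."""
    if not list_tensor:
        raise ValueError("list_tensor must not be empty")
    total = sum(abs(x) ** p for t in list_tensor for x in t)
    return total ** (1.0 / p)


euclidean = list_norm([[3.0], [4.0]])        # sqrt(9 + 16)
manhattan = list_norm([[3.0], [-4.0]], p=1)  # |3| + |-4|
```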

jboat.utils.op_utils.manual_update(optimizer, variables)

Manually update variables using gradients stored in _custom_grad.

Parameters:
  • optimizer (jittor.optim.Optimizer) – The Jittor optimizer instance.

  • variables (List[jittor.Var]) – A list of Jittor variables to be updated.

jboat.utils.op_utils.neumann(params, hparams, upper_loss, lower_loss, k, fp_map, tol=1e-10)

Compute hyperparameter gradients using the Neumann series approximation.

Parameters:
  • params (List[jittor.Var]) – List of parameters for the lower-level optimization problem.

  • hparams (List[jittor.Var]) – List of hyperparameters for the upper-level optimization problem.

  • upper_loss (jittor.Var) – Loss value of the upper-level problem.

  • lower_loss (jittor.Var) – Loss value of the lower-level problem.

  • k (int) – Number of iterations for the Neumann series approximation.

  • fp_map (Callable) – Fixed-point map function that computes updates to lower-level parameters.

  • tol (float, optional) – Tolerance for early stopping based on convergence, by default 1e-10.

Returns:

Hyperparameter gradients computed using the Neumann series approximation.

Return type:

List[jittor.Var]
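
The core of a Neumann-series approximation is replacing an inverse by a truncated power series. The sketch below shows that step on an explicit 2x2 matrix. It is not the library's implementation: the matrix J stands in for the Jacobian of the fixed-point map, the helper names are hypothetical, and the stopping rule based on the norm of the latest term is an assumed reading of tol.

```python
from typing import List

Matrix = List[List[float]]


def matvec_t(J: Matrix, v: List[float]) -> List[float]:
    """Return J^T v."""
    return [sum(J[i][j] * v[i] for i in range(len(v))) for j in range(len(J[0]))]


def neumann_solve(J: Matrix, g: List[float], k: int, tol: float = 1e-10) -> List[float]:
    """Approximate (I - J^T)^{-1} g by the truncated series sum_{i=0..k} (J^T)^i g."""
    term = list(g)   # current series term (J^T)^i g
    total = list(g)  # running partial sum
    for _ in range(k):
        term = matvec_t(J, term)
        total = [t + s for t, s in zip(total, term)]
        if sum(t * t for t in term) ** 0.5 < tol:  # terms have become negligible
            break
    return total


J = [[0.5, 0.0],
     [0.0, 0.25]]  # contraction: spectral radius below 1
v = neumann_solve(J, [1.0, 3.0], k=200)  # exact answer is [2, 4]
```

The series converges only when the spectral radius of J is below 1, which is why the fixed-point map is normally required to be a contraction.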

jboat.utils.op_utils.require_model_grad(model=None)

Ensure all model parameters require gradients.

Parameters:

model (jittor.Module) – The model to check and update parameters.

Raises:

AssertionError – If the model is not defined.

jboat.utils.op_utils.stop_grads(grads)

Detach and stop gradient computation for a list of gradients.

Parameters:

grads (List[jittor.Var]) – The gradients to process.

Returns:

Detached gradients with requires_grad set to False.

Return type:

List[jittor.Var]

jboat.utils.op_utils.stop_model_grad(model=None)

Stop gradient computation for all parameters in a model.

Parameters:

model (jittor.Module) – The model to stop gradients for.

Raises:

AssertionError – If the model is not defined.

jboat.utils.op_utils.update_grads(grads, model)

Update the _custom_grad attribute of the model’s parameters.

Parameters:
  • grads (List[jittor.Var]) – Gradients to be applied to the parameters.

  • model (jittor.Module) – Model whose parameters will be updated.

jboat.utils.op_utils.update_tensor_grads(hparams, grads)

Update gradients for Jittor variables manually.

Parameters:
  • hparams (List[jittor.Var]) – List of Jittor variables representing the hyperparameters.

  • grads (List[jittor.Var]) – List of gradients corresponding to the hyperparameters.

Raises:

ValueError – If a variable has gradients stopped (stop_grad) and therefore cannot be updated.