jboat.utils
Submodules
jboat.utils.op_utils
- class jboat.utils.op_utils.DynamicalSystemRules
Bases: object
A class to store and manage dynamical-system gradient operator rules.
- class jboat.utils.op_utils.HyperGradientRules
Bases: object
A class to store and manage hyper-gradient operator rules.
- class jboat.utils.op_utils.ResultStore
Bases: object
A simple class to store and manage intermediate results of hyper-gradient computation.
- jboat.utils.op_utils.average_grad(model, batch_size)
Divide the gradients of all model parameters by the batch size.
- Parameters:
model (jittor.Module) – The model whose gradients need to be averaged.
batch_size (int) – The batch size to divide gradients by.
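The effect of average_grad can be illustrated in plain Python. This is a minimal sketch, not the jboat implementation: Param and average_grad_sketch are hypothetical stand-ins for Jittor parameters and their gradient buffers, and the gradients are assumed to have been summed over the batch before the call.

```python
# Hypothetical stand-in for a model parameter (not the Jittor API):
# a flat gradient buffer stored as a list of floats.
class Param:
    def __init__(self, grad):
        self.grad = list(grad)


def average_grad_sketch(params, batch_size):
    """Divide every parameter's gradient by batch_size, in place."""
    for p in params:
        p.grad = [g / batch_size for g in p.grad]


# Gradients assumed to be summed over a batch of 4 samples.
params = [Param([4.0, 8.0]), Param([2.0])]
average_grad_sketch(params, batch_size=4)
print([p.grad for p in params])  # [[1.0, 2.0], [0.5]]
```

Dividing a summed gradient by the batch size turns it into the mean per-sample gradient, which keeps the effective step size independent of the batch size.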
- jboat.utils.op_utils.cat_list_to_tensor(list_tx)
Concatenate a list of tensors into a single flattened tensor.
- Parameters:
list_tx (List[jittor.Var]) – The list of tensors to concatenate.
- Returns:
A single flattened tensor.
- Return type:
jittor.Var
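A plain-Python sketch of the flatten-then-concatenate behaviour, with nested lists standing in for jittor.Var tensors (an assumption made only so the example runs without Jittor). Each tensor is flattened in row-major order, as tensor flattening typically is, and the results are joined in list order; cat_list_to_flat is a hypothetical name.

```python
def flatten(x):
    """Recursively flatten a nested list (a stand-in for a tensor) in row-major order."""
    if isinstance(x, list):
        out = []
        for item in x:
            out.extend(flatten(item))
        return out
    return [x]


def cat_list_to_flat(list_tx):
    """Flatten each 'tensor' and concatenate the results into one flat vector."""
    flat = []
    for t in list_tx:
        flat.extend(flatten(t))
    return flat


weight = [[1.0, 2.0], [3.0, 4.0]]  # a 2x2 matrix
bias = [5.0, 6.0]                  # a vector of length 2
print(cat_list_to_flat([weight, bias]))  # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Packing a list of parameter tensors into one vector like this lets vector operations (dot products, norms) treat the whole parameter set as a single point.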
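- jboat.utils.op_utils.cg_step(Ax, b, max_iter=100, epsilon=1e-05)
Run the Conjugate Gradient (CG) method to approximately solve the linear system Ax = b, accessing A only through matrix-vector products. CG assumes that A is symmetric positive definite.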
- Parameters:
Ax (Callable) – Function that computes the matrix-vector product Ax for a given x.
b (List[jittor.Var]) – Right-hand side of the equation Ax = b.
max_iter (int, optional) – Maximum number of iterations for the CG method, by default 100.
epsilon (float, optional) – Convergence threshold for the residual norm, by default 1e-5.
- Returns:
Solution vector x that approximately solves Ax = b.
- Return type:
List[jittor.Var]
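The CG iteration behind cg_step can be sketched on plain Python lists instead of lists of jittor.Var. This is a minimal sketch, not the jboat code: cg_sketch is a hypothetical name, and the stopping rule (residual norm below epsilon) is assumed from the description of the epsilon parameter. As with cg_step, A is supplied only through the callable Ax.

```python
def dot(u, v):
    """Inner product of two flat vectors."""
    return sum(a * b for a, b in zip(u, v))


def cg_sketch(Ax, b, max_iter=100, epsilon=1e-5):
    """Solve A x = b for symmetric positive-definite A, given only the map x -> A x."""
    x = [0.0] * len(b)  # start from x = 0, so the initial residual is b
    r = list(b)         # residual r = b - A x
    p = list(r)         # search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < epsilon:  # stop once ||r|| drops below epsilon
            break
        Ap = Ax(p)
        alpha = rs_old / dot(p, Ap)  # exact step length along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = dot(r, r)
        # New direction, A-conjugate to the previous ones.
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


A = [[4.0, 1.0],
     [1.0, 3.0]]  # symmetric positive definite


def matvec(v):
    return [sum(a * vi for a, vi in zip(row, v)) for row in A]


x = cg_sketch(matvec, [1.0, 2.0])
print([round(v, 6) for v in x])  # [0.090909, 0.636364]
```

The exact solution here is x = [1/11, 7/11]; in exact arithmetic CG reaches it in at most n = 2 iterations, which is why the method only needs matrix-vector products rather than the matrix itself.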
- jboat.utils.op_utils.conjugate_gradient(params, hparams, upper_loss, lower_loss, K, fp_map, tol=1e-10, stochastic=False)
Compute hyperparameter gradients using the Conjugate Gradient method.
- Parameters:
params (List[jittor.Var]) – List of parameters for the lower-level optimization problem.
hparams (List[jittor.Var]) – List of hyperparameters for the upper-level optimization problem.
upper_loss (jittor.Var) – Loss value of the upper-level problem.
lower_loss (jittor.Var) – Loss value of the lower-level problem.
K (int) – Maximum number of iterations for the Conjugate Gradient method.
fp_map (Callable) – Fixed-point map function that computes updates to lower-level parameters.
tol (float, optional) – Tolerance for early stopping based on convergence, by default 1e-10.
stochastic (bool, optional) – If True, recompute the fixed-point map during each iteration, by default False.
- Returns:
Hyperparameter gradients computed using the Conjugate Gradient method.
- Return type:
List[jittor.Var]
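This page does not spell out what conjugate_gradient solves. Assuming it follows the standard implicit-differentiation treatment of a fixed-point lower level (the reading suggested by its fp_map argument), the computation is as follows, with E the upper_loss, Φ the fp_map, w the params and λ the hparams; these symbols are introduced here for illustration only.

```latex
w^{*}(\lambda) = \Phi\bigl(w^{*}(\lambda), \lambda\bigr),
\qquad f(\lambda) = E\bigl(w^{*}(\lambda), \lambda\bigr)

\nabla f(\lambda) = \partial_{\lambda} E + \bigl(\partial_{\lambda}\Phi\bigr)^{\top} v,
\qquad \text{where } \bigl(I - \partial_{w}\Phi\bigr)^{\top} v = \partial_{w} E
```

All derivatives are evaluated at w = w*(λ), and CG would be used for at most K iterations to solve the linear system for v. CG needs a symmetric positive-definite system: if, as is common, Φ is a gradient step Φ(w, λ) = w − η∇_w L(w, λ) on the lower_loss L, then I − ∂_wΦ = η∇²_w L, which is symmetric and is positive definite when L is strongly convex in w.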
- jboat.utils.op_utils.copy_parameter_from_list(y, z)
Copy parameters from a list to the parameters of a Jittor model.
- Parameters:
y (jittor.Module) – Jittor model with parameters to be updated.
z (List[jittor.Var]) – List of variables to copy from.
- Returns:
Updated model.
- Return type:
jittor.Module
- jboat.utils.op_utils.custom_grad(outputs, inputs, grad_outputs=None, retain_graph=False)
Compute the vector-Jacobian product in Jittor, mimicking PyTorch’s torch.autograd.grad.
- Parameters:
outputs (Sequence[jittor.Var]) – Outputs of the differentiated function.
inputs (Sequence[jittor.Var]) – Inputs with respect to which the gradient will be computed.
grad_outputs (Sequence[jittor.Var], optional) – Gradients with respect to the outputs, by default None.
retain_graph (bool, optional) – Whether to retain the computation graph after computing the gradients, by default False.
- Returns:
Gradients with respect to the inputs.
- Return type:
List[jittor.Var]
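The quantity custom_grad returns is a vector-Jacobian product: for outputs f(x) and grad_outputs v it yields vᵀJ, where J[i][j] = ∂f_i/∂x_j (for a scalar output with grad_outputs=None, v is presumably taken as 1, as in torch.autograd.grad). The sketch below only illustrates that quantity using central finite differences on plain lists, whereas the real function relies on Jittor's automatic differentiation; vjp_finite_diff is a hypothetical helper.

```python
def vjp_finite_diff(f, x, v, h=1e-6):
    """Approximate v^T J, where J[i][j] = d f_i / d x_j, by central differences."""
    result = []
    for j in range(len(x)):
        x_plus = list(x)
        x_minus = list(x)
        x_plus[j] += h
        x_minus[j] -= h
        f_plus, f_minus = f(x_plus), f(x_minus)
        # Column j of J, contracted with v.
        result.append(sum(vi * (fp - fm) / (2 * h)
                          for vi, fp, fm in zip(v, f_plus, f_minus)))
    return result


def f(x):
    return [x[0] * x[1], x[0] + x[1]]  # J = [[x1, x0], [1, 1]]


g = vjp_finite_diff(f, [2.0, 3.0], v=[1.0, 1.0])
print([round(gi, 4) for gi in g])  # [4.0, 3.0]
```

At x = [2, 3] the Jacobian is [[3, 2], [1, 1]], so vᵀJ with v = [1, 1] is [4, 3]. Reverse-mode autodiff computes the same product without ever forming J.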
- jboat.utils.op_utils.get_outer_gradients(outer_loss, params, hparams, retain_graph=True)
Compute gradients of the outer-level loss with respect to parameters and hyperparameters.
- Parameters:
outer_loss (jittor.Var) – The scalar loss from the outer-level optimization problem.
params (List[jittor.Var]) – The list of parameters for which gradients with respect to the outer loss are computed.
hparams (List[jittor.Var]) – The list of hyperparameters for which gradients with respect to the outer loss are computed.
retain_graph (bool, optional) – Whether to retain the computation graph after computing the gradients, by default True.
- Returns:
Gradients with respect to parameters and hyperparameters.
- Return type:
Tuple[List[jittor.Var], List[jittor.Var]]
- jboat.utils.op_utils.grad_unused_zero(output, inputs, retain_graph=False)
Compute the gradients of the output with respect to the inputs, filling the gradients of unused inputs with zeros.
- Parameters:
output (jittor.Var) – The output tensor to compute gradients for.
inputs (List[jittor.Var]) – The input tensors to compute gradients with respect to.
retain_graph (bool, optional) – Whether to retain the computation graph, by default False.
- Returns:
Gradients with respect to the inputs, with zeros for unused gradients.
- Return type:
Tuple[jittor.Var]
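The zero-filling step can be sketched as follows. It assumes the raw gradient computation marks an input that does not influence the output with None (as torch.autograd.grad does with allow_unused=True); how Jittor reports such inputs is not stated on this page, and fill_unused_with_zeros is a hypothetical name operating on flat-list stand-ins for tensors.

```python
def fill_unused_with_zeros(grads, inputs):
    """Replace missing (None) gradients with zeros shaped like the matching input."""
    return tuple(
        [0.0] * len(x) if g is None else g
        for g, x in zip(grads, inputs)
    )


inputs = [[1.0, 2.0], [3.0]]     # flat-list stand-ins for two tensors
raw_grads = [[0.5, -0.5], None]  # the output does not depend on inputs[1]
print(fill_unused_with_zeros(raw_grads, inputs))  # ([0.5, -0.5], [0.0])
```

Returning explicit zeros keeps the result aligned one-to-one with inputs, so downstream list arithmetic never has to special-case missing entries.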
- jboat.utils.op_utils.l2_reg(parameters)
Compute the L2 regularization term over a list of parameters (Jittor version).
- Parameters:
parameters (List[jittor.Var]) – The list of parameters to regularize.
- Returns:
The L2 regularization loss.
- Return type:
jittor.Var
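A plain-Python sketch of an L2 regularization term, assuming it is the sum of squared L2 norms of all parameters (some codebases use the unsquared norm instead, so treat this as an assumption). Parameters are flat lists of floats, and l2_reg_sketch is a hypothetical name.

```python
def l2_reg_sketch(parameters):
    """Sum of squared L2 norms of all parameters (flat lists of floats)."""
    return sum(v * v for p in parameters for v in p)


weights = [[3.0, 4.0], [1.0]]
print(l2_reg_sketch(weights))  # 26.0
```

In a regularized objective such a term is typically scaled by a coefficient before being added to a loss.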
- jboat.utils.op_utils.list_tensor_matmul(list1, list2)
Compute the inner product of two lists of tensors: multiply corresponding tensors element-wise and sum all the products.
- Parameters:
list1 (List[jittor.Var]) – The first list of tensors.
list2 (List[jittor.Var]) – The second list of tensors.
- Returns:
The resulting scalar from element-wise multiplication and summation.
- Return type:
jittor.Var
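The inner product over two lists of tensors can be sketched with flat lists standing in for jittor.Var (list_inner_product is a hypothetical name). The result equals the dot product of the two lists after each has been flattened and concatenated, which is the kind of inner product a CG solver over lists of parameters needs.

```python
def list_inner_product(list1, list2):
    """Multiply corresponding entries of paired tensors and sum everything."""
    return sum(a * b
               for t1, t2 in zip(list1, list2)
               for a, b in zip(t1, t2))


u = [[1.0, 2.0], [3.0]]
w = [[4.0, 5.0], [6.0]]
print(list_inner_product(u, w))  # 32.0
```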
- jboat.utils.op_utils.list_tensor_norm(list_tensor, p=2)
Compute the p-norm of a list of tensors.
- Parameters:
list_tensor (List[jittor.Var]) – The list of tensors to compute the norm for.
p (float, optional) – The order of the norm, by default 2 (Euclidean norm).
- Returns:
The computed p-norm of the list of tensors.
- Return type:
jittor.Var
- Raises:
ValueError – If the list of tensors is empty.
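A plain-Python sketch of a p-norm over a list of tensors, assuming the norm is taken over all entries as if the tensors were concatenated into one vector (the page does not say whether per-tensor norms are combined differently). The empty-list check mirrors the documented ValueError; list_norm_sketch is a hypothetical name.

```python
def list_norm_sketch(list_tensor, p=2):
    """p-norm over all entries of a list of tensors (flat lists of floats)."""
    if not list_tensor:
        raise ValueError("list_tensor must not be empty")
    total = sum(abs(v) ** p for t in list_tensor for v in t)
    return total ** (1.0 / p)


print(list_norm_sketch([[3.0], [-4.0]]))       # 5.0
print(list_norm_sketch([[3.0], [-4.0]], p=1))  # 7.0
```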
- jboat.utils.op_utils.manual_update(optimizer, variables)
Manually update variables using gradients stored in _custom_grad.
- Parameters:
optimizer (jittor.optim.Optimizer) – The Jittor optimizer instance.
variables (List[jittor.Var]) – A list of Jittor variables to be updated.
- jboat.utils.op_utils.neumann(params, hparams, upper_loss, lower_loss, k, fp_map, tol=1e-10)
Compute hyperparameter gradients using the Neumann series approximation.
- Parameters:
params (List[jittor.Var]) – List of parameters for the lower-level optimization problem.
hparams (List[jittor.Var]) – List of hyperparameters for the upper-level optimization problem.
upper_loss (jittor.Var) – Loss value of the upper-level problem.
lower_loss (jittor.Var) – Loss value of the lower-level problem.
k (int) – Number of iterations for the Neumann series approximation.
fp_map (Callable) – Fixed-point map function that computes updates to lower-level parameters.
tol (float, optional) – Tolerance for early stopping based on convergence, by default 1e-10.
- Returns:
Hyperparameter gradients computed using the Neumann series approximation.
- Return type:
List[jittor.Var]
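Assuming the usual fixed-point formulation, the hypergradient needs a vector v solving (I − Jᵀ) v = g, where J is the Jacobian of fp_map with respect to params and g is the gradient of upper_loss with respect to params. A Neumann series replaces that solve with the truncated sum v ≈ Σ_{i=0}^{k} (Jᵀ)^i g, which converges when the spectral radius of J is below 1. The sketch below shows only this linear-algebra core on small dense matrices; neumann_sketch, matvec and transpose are hypothetical helpers, and the real function works with Jacobian-vector products on jittor.Var lists instead.

```python
def matvec(M, v):
    """Dense matrix-vector product on lists of lists."""
    return [sum(m * vi for m, vi in zip(row, v)) for row in M]


def transpose(M):
    return [list(col) for col in zip(*M)]


def neumann_sketch(J, g, k):
    """Approximate v = (I - J^T)^{-1} g by sum_{i=0}^{k} (J^T)^i g."""
    Jt = transpose(J)
    term = list(g)  # i = 0 term: (J^T)^0 g = g
    v = list(g)
    for _ in range(k):
        term = matvec(Jt, term)  # next power: (J^T)^{i+1} g
        v = [vi + ti for vi, ti in zip(v, term)]
    return v


J = [[0.5, 0.1],
     [0.0, 0.25]]  # eigenvalues 0.5 and 0.25, so the series converges
g = [1.0, 1.0]
print([round(x, 6) for x in neumann_sketch(J, g, k=60)])  # [2.0, 1.6]
```

Solving (I − Jᵀ) v = g directly gives v = [2, 1.6], which the truncated series matches to printed precision. Smaller k trades accuracy for fewer Jacobian-vector products.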
- jboat.utils.op_utils.require_model_grad(model=None)
Ensure all model parameters require gradients.
- Parameters:
model (jittor.Module) – The model whose parameters are set to require gradients.
- Raises:
AssertionError – If the model is not defined.
- jboat.utils.op_utils.stop_grads(grads)
Detach and stop gradient computation for a list of gradients.
- Parameters:
grads (List[jittor.Var]) – The gradients to process.
- Returns:
Detached gradients with requires_grad set to False.
- Return type:
List[jittor.Var]
- jboat.utils.op_utils.stop_model_grad(model=None)
Stop gradient computation for all parameters in a model.
- Parameters:
model (jittor.Module) – The model to stop gradients for.
- Raises:
AssertionError – If the model is not defined.
- jboat.utils.op_utils.update_grads(grads, model)
Update the custom_grad attribute of the model’s parameters.
- Parameters:
grads (List[jittor.Var]) – Gradients to be applied to the parameters.
model (jittor.Module) – Model whose parameters will be updated.
- jboat.utils.op_utils.update_tensor_grads(hparams, grads)
Update gradients for Jittor variables manually.
- Parameters:
hparams (List[jittor.Var]) – List of Jittor variables representing the hyperparameters.
grads (List[jittor.Var]) – List of gradients corresponding to the hyperparameters.
- Raises:
ValueError – If a variable has stop_grad set and therefore cannot be updated.