Source code for jboat.utils.op_utils

import jittor as jit
from typing import List, Callable, Dict


class ResultStore:
    """
    Ordered container for intermediate results of hyper-gradient computation.

    Entries live in ``self.results`` in insertion order; each entry is a
    single-item dictionary mapping the result's name to its payload.
    """

    def __init__(self):
        self.results = []

    def add(self, name: str, result: Dict):
        """
        Append a named result to the store.

        Parameters
        ----------
        name : str
            Label of the entry (e.g., 'gradient_operator_results_0').
        result : Dict
            The payload stored under ``name``.
        """
        entry = {name: result}
        self.results.append(entry)

    def clear(self):
        """Forget all entries by rebinding ``results`` to a fresh empty list."""
        self.results = list()

    def get_results(self) -> List[Dict]:
        """Return the internal list of stored entries (the list itself, not a copy)."""
        return self.results
class DynamicalSystemRules:
    """
    Holder for the class-wide gradient operator order.

    The order is a list of groups, each group being a list of operator names.
    It is stored on the class itself, so every caller shares the same value.
    """

    # Default order; replaced wholesale by set_gradient_order().
    _gradient_order = [
        ["GDA", "DI"],
        ["NGD", "DM"],
    ]

    @staticmethod
    def get_gradient_order() -> List[List[str]]:
        """
        Return the gradient operator order currently in effect.

        Returns
        -------
        List[List[str]]
            The stored order (the same list object, not a copy).
        """
        return DynamicalSystemRules._gradient_order

    @staticmethod
    def set_gradient_order(new_order: List[List[str]]):
        """
        Replace the gradient operator order.

        Parameters
        ----------
        new_order : List[List[str]]
            The order to install; must be a list whose items are all lists.

        Raises
        ------
        ValueError
            If ``new_order`` is not a list, or any of its items is not a list.
        """
        is_list_of_lists = isinstance(new_order, list) and all(
            isinstance(group, list) for group in new_order
        )
        if not is_list_of_lists:
            raise ValueError("Gradient order must be a list of lists.")
        DynamicalSystemRules._gradient_order = new_order
class HyperGradientRules:
    """
    Holder for the class-wide gradient operator order.

    The order is a list of groups, each group being a list of operator names.
    It is stored on the class itself, so every caller shares the same value.
    """

    # Default order; replaced wholesale by set_gradient_order().
    _gradient_order = [
        ["PTT", "FOA", "RGT"],
        ["IAD", "RAD", "FD", "IGA"],
        ["CG", "NS"],
    ]

    @staticmethod
    def get_gradient_order() -> List[List[str]]:
        """
        Return the gradient operator order currently in effect.

        Returns
        -------
        List[List[str]]
            The stored order (the same list object, not a copy).
        """
        return HyperGradientRules._gradient_order

    @staticmethod
    def set_gradient_order(new_order: List[List[str]]):
        """
        Replace the gradient operator order.

        Parameters
        ----------
        new_order : List[List[str]]
            The order to install; must be a list whose items are all lists.

        Raises
        ------
        ValueError
            If ``new_order`` is not a list, or any of its items is not a list.
        """
        is_list_of_lists = isinstance(new_order, list) and all(
            isinstance(group, list) for group in new_order
        )
        if not is_list_of_lists:
            raise ValueError("Gradient order must be a list of lists.")
        HyperGradientRules._gradient_order = new_order
def l2_reg(parameters):
    """
    Sum the squared entries of every parameter (Jittor version).

    Parameters
    ----------
    parameters : List[jt.Var]
        Parameters that contribute to the L2 regularization term.

    Returns
    -------
    jt.Var
        The accumulated penalty, seeded from ``jit.zeros(1)``.
    """
    penalty = jit.zeros(1)
    for weight in parameters:
        squared = weight * weight
        penalty += squared.sum()
    return penalty
def grad_unused_zero(output, inputs, retain_graph=False):
    """
    Differentiate ``output`` w.r.t. ``inputs`` and zero out unusable gradients.

    A gradient is replaced by a zero tensor shaped like its input when it is
    ``None`` or when it contains at least one NaN entry.

    Parameters
    ----------
    output : jittor.Var
        The output tensor to differentiate.
    inputs : List[jittor.Var]
        The tensors to differentiate with respect to.
    retain_graph : bool, optional
        Forwarded to ``jit.grad``, by default False.

    Returns
    -------
    Tuple[jittor.Var]
        One gradient per input, in the order of ``inputs``.
    """
    raw_grads = jit.grad(output, inputs, retain_graph=retain_graph)

    cleaned = []
    for grad, var in zip(raw_grads, inputs):
        if grad is None or (jit.isnan(grad).any()):
            cleaned.append(jit.zeros_like(var))
        else:
            cleaned.append(grad)
    return tuple(cleaned)
def list_tensor_matmul(list1, list2):
    """
    Inner product of two tensor lists.

    Corresponding tensors are multiplied element-wise, each product is summed,
    and the per-pair sums are accumulated. Pairing follows ``zip``, so extra
    items in the longer list are ignored.

    Parameters
    ----------
    list1 : List[jittor.Var]
        The first list of tensors.
    list2 : List[jittor.Var]
        The second list of tensors.

    Returns
    -------
    jittor.Var
        The accumulated scalar; the integer ``0`` when no pairs are present.
    """
    total = 0
    for left, right in zip(list1, list2):
        product = left * right
        total += product.sum()
    return total
def list_tensor_norm(list_tensor, p=2):
    """
    Compute the p-norm of a list of tensors, treated as one long vector.

    Parameters
    ----------
    list_tensor : List[jittor.Var]
        The tensors to compute the norm for. Any iterable is accepted; it is
        materialised once so that one-shot iterators work too.
    p : float, optional
        The order of the norm, by default 2 (Euclidean norm).

    Returns
    -------
    jittor.Var
        ``(sum_i sum(|t_i| ** p)) ** (1 / p)``.

    Raises
    ------
    ValueError
        If the list of tensors is empty.
    """
    tensors = list(list_tensor)
    if not tensors:
        # Without this guard an empty input silently yields a plain Python
        # number instead of a tensor, contradicting the documented contract.
        raise ValueError("list_tensor must contain at least one tensor.")

    norm = 0
    for t in tensors:
        norm += (t.abs() ** p).sum()
    return norm ** (1 / p)
def require_model_grad(model=None):
    """
    Ensure all model parameters require gradients.

    Every parameter that reports ``is_stop_grad()`` has gradient tracking
    switched back on in place via ``start_grad()``; other parameters are left
    untouched.

    Parameters
    ----------
    model : jittor.Module
        The model to check and update parameters.

    Raises
    ------
    AssertionError
        If the model is not defined.
    """
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if model is None:
        raise AssertionError("The module is not defined!")

    for param in model.parameters():
        if param.is_stop_grad():
            # Rebinding the loop variable (e.g. `param = param.clone()`) would
            # never reach the model; the parameter has to be modified in place.
            param.start_grad()
def update_grads(grads, model):
    """
    Accumulate gradients into the ``_custom_grad`` attribute of model parameters.

    Gradients are paired with ``model.parameters()`` in order. A parameter
    without ``_custom_grad`` receives a clone of its gradient; otherwise the
    gradient is added in place to the existing value.

    Parameters
    ----------
    grads : List[jittor.Var]
        Gradients to be applied to the parameters.
    model : jittor.Module
        Model whose parameters will be updated.
    """
    for grad, param in zip(grads, model.parameters()):
        if hasattr(param, "_custom_grad"):
            param._custom_grad += grad
        else:
            param._custom_grad = grad.clone()
def manual_update(optimizer, variables):
    """
    Apply one plain gradient step using the gradients kept in ``_custom_grad``.

    Only parameters that are (by identity) members of ``variables``, carry a
    ``_custom_grad`` attribute, and whose gradient shape equals their own shape
    are touched. Each such parameter is decreased in place by ``lr * grad`` and
    its ``_custom_grad`` is then multiplied by zero.

    Parameters
    ----------
    optimizer : jittor.optim.Optimizer
        The Jittor optimizer instance; its ``param_groups`` are walked and the
        group's ``"lr"`` is used, falling back to ``optimizer.lr``.
    variables : List[jittor.Var]
        A list of Jittor variables to be updated.
    """
    selected = set(map(id, variables))

    for group in optimizer.param_groups:
        step_size = group.get("lr", optimizer.lr)
        for param in group["params"]:
            # Ignore parameters that were not requested or have no stored gradient.
            is_target = id(param) in selected
            if not (is_target and hasattr(param, "_custom_grad")):
                continue

            grad = param._custom_grad
            # Mismatching shapes are skipped rather than broadcast.
            if grad.shape == param.shape:
                param -= step_size * grad
                param._custom_grad *= 0
def update_tensor_grads(hparams, grads):
    """
    Accumulate gradients into the ``_custom_grad`` attribute of hyperparameters.

    Hyperparameters and gradients are paired in order. A variable that reports
    ``is_stop_grad()`` is skipped. Otherwise the first gradient is stored as a
    detached clone and later gradients are added in place.

    Parameters
    ----------
    hparams : List[jittor.Var]
        List of Jittor variables representing the hyperparameters.
    grads : List[jittor.Var]
        List of gradients corresponding to the hyperparameters.
    """
    for hparam, grad in zip(hparams, grads):
        if hparam.is_stop_grad():
            continue
        if hasattr(hparam, "_custom_grad"):
            hparam._custom_grad += grad
        else:
            hparam._custom_grad = grad.clone().detach()
def stop_grads(grads):
    """
    Detach every gradient in a list and stop gradient computation through it.

    ``None`` entries are passed through unchanged so that positions line up
    with the input.

    Parameters
    ----------
    grads : List[jittor.Var]
        The gradients to process.

    Returns
    -------
    List[jittor.Var]
        For each entry, the result of ``detach().stop_grad()``, or ``None``.
    """
    frozen = []
    for grad in grads:
        if grad is None:
            frozen.append(grad)
        else:
            frozen.append(grad.detach().stop_grad())
    return frozen
def average_grad(model, batch_size, optimizer=None):
    """
    Divide the gradients of all model parameters by the batch size.

    Parameters
    ----------
    model : jittor.Module
        The model whose gradients need to be averaged.
    batch_size : int
        The batch size to divide gradients by.
    optimizer : jittor.optim.Optimizer, optional
        Optimizer that holds the gradients. When given it is forwarded to
        ``param.opt_grad(optimizer)``. When omitted, ``opt_grad`` is called
        without arguments exactly as before.
        NOTE(review): Jittor documents ``Var.opt_grad`` as taking the
        optimizer, so callers should normally pass it; confirm against the
        installed Jittor version.
    """
    opt_args = () if optimizer is None else (optimizer,)
    for param in model.parameters():
        # Fetch the gradient once and reuse it for the check and the update.
        grad = param.opt_grad(*opt_args)
        if grad is not None:
            grad.update(grad / batch_size)
def stop_model_grad(model=None):
    """
    Stop gradient computation for all parameters in a model.

    Parameters
    ----------
    model : jittor.Module
        The model to stop gradients for.

    Raises
    ------
    AssertionError
        If the model is not defined.
    """
    # Raised explicitly (instead of `assert`) so the check survives `python -O`.
    if model is None:
        raise AssertionError("The module is not defined!")

    for param in model.parameters():
        param.stop_grad()
def cat_list_to_tensor(list_tx):
    """
    Flatten every tensor in a list and join the pieces into one tensor.

    Parameters
    ----------
    list_tx : List[jittor.Var]
        The list of tensors to concatenate.

    Returns
    -------
    jittor.Var
        The result of ``jit.concat`` over the flattened tensors, in list order.
    """
    flat_pieces = []
    for tensor in list_tx:
        flat_pieces.append(tensor.flatten())
    return jit.concat(flat_pieces)
def copy_parameter_from_list(y, z):
    """
    Overwrite a model's parameters with values taken from a list.

    Parameters and list items are paired in order. Each parameter is updated
    with a detached clone of its counterpart and then flagged as requiring
    gradients.

    Parameters
    ----------
    y : jittor.Module
        Jittor model with parameters to be updated.
    z : List[jittor.Var]
        List of variables to copy from.

    Returns
    -------
    jittor.Module
        The same model object ``y``, after the update.
    """
    for target, source in zip(y.parameters(), z):
        snapshot = source.clone().detach()
        target.update(snapshot)
        target.requires_grad = True
    return y
def get_outer_gradients(outer_loss, params, hparams, retain_graph=True):
    """
    Differentiate the outer-level loss w.r.t. parameters and hyperparameters.

    Both gradient sets come from ``grad_unused_zero``, evaluated first for
    ``params`` and then for ``hparams``.

    Parameters
    ----------
    outer_loss : jittor.Var
        The scalar loss from the outer-level optimization problem.
    params : List[jittor.Var]
        The parameters to differentiate the outer loss with respect to.
    hparams : List[jittor.Var]
        The hyperparameters to differentiate the outer loss with respect to.
    retain_graph : bool, optional
        Forwarded to both gradient computations, by default True.

    Returns
    -------
    Tuple[Tuple[jittor.Var], Tuple[jittor.Var]]
        Gradients with respect to parameters and hyperparameters, in that order.
    """
    return tuple(
        grad_unused_zero(outer_loss, targets, retain_graph=retain_graph)
        for targets in (params, hparams)
    )
def custom_grad(outputs, inputs, grad_outputs=None, retain_graph=False):
    """
    Vector-Jacobian product for Jittor, mimicking PyTorch's autograd.grad.

    Each output is multiplied by its matching ``grad_outputs`` entry and
    summed; the sum of those scalars is differentiated with ``jit.grad``.

    Parameters
    ----------
    outputs : Sequence[jittor.Var]
        Outputs of the differentiated function; a single value is wrapped in
        a one-element tuple.
    inputs : Sequence[jittor.Var]
        Inputs with respect to which the gradient will be computed.
    grad_outputs : Sequence[jittor.Var], optional
        Weights applied to the outputs. When None, ``jit.ones_like`` of each
        output is used; a single value is wrapped in a one-element tuple.
    retain_graph : bool, optional
        Forwarded to ``jit.grad``, by default False.

    Returns
    -------
    List[jittor.Var]
        Gradients with respect to the inputs, as returned by ``jit.grad``.
    """
    outs = outputs if isinstance(outputs, (tuple, list)) else (outputs,)

    if grad_outputs is None:
        weights = [jit.ones_like(out) for out in outs]
    elif isinstance(grad_outputs, (tuple, list)):
        weights = grad_outputs
    else:
        weights = (grad_outputs,)

    assert len(outs) == len(
        weights
    ), "outputs and grad_outputs must have the same length."

    # Scalar surrogate whose gradient equals the vector-Jacobian product.
    weighted_sum = 0
    for out, weight in zip(outs, weights):
        weighted_sum = weighted_sum + (out * weight).sum()

    return jit.grad(weighted_sum, inputs, retain_graph=retain_graph)
def neumann(params, hparams, upper_loss, lower_loss, k, fp_map, tol=1e-10):
    """
    Hyperparameter gradients via a truncated Neumann series.

    Starting from the outer gradient w.r.t. ``params``, the fixed-point map's
    vector-Jacobian product is applied repeatedly and the terms are summed.
    Iteration stops after ``k`` terms, or earlier once the norm of the change
    in the running sum falls below ``tol``.

    Parameters
    ----------
    params : List[jittor.Var]
        List of parameters for the lower-level optimization problem.
    hparams : List[jittor.Var]
        List of hyperparameters for the upper-level optimization problem.
    upper_loss : jittor.Var
        Loss function for the upper-level problem.
    lower_loss : jittor.Var
        Loss function for the lower-level problem.
    k : int
        Number of iterations for the Neumann series approximation.
    fp_map : Callable
        Fixed-point map, called once as ``fp_map(params, lower_loss)``.
    tol : float, optional
        Tolerance for early stopping based on convergence, by default 1e-10.

    Returns
    -------
    List[jittor.Var]
        Hyperparameter gradients computed using the Neumann series approximation.
    """
    grad_outer_w, grad_outer_hparams = get_outer_gradients(upper_loss, params, hparams)
    w_mapped = fp_map(params, lower_loss)

    term = grad_outer_w  # latest series term
    series = grad_outer_w  # running sum of all terms so far
    series_vec = cat_list_to_tensor(series)

    for _ in range(k):
        previous_vec = series_vec
        term = custom_grad(w_mapped, params, grad_outputs=term, retain_graph=True)
        series = [acc + new for acc, new in zip(series, term)]
        series_vec = cat_list_to_tensor(series)
        change = float(jit.norm(series_vec - previous_vec))
        if change < tol:
            break

    hyper_grads = custom_grad(w_mapped, hparams, grad_outputs=series)
    return [g + v for g, v in zip(hyper_grads, grad_outer_hparams)]
def conjugate_gradient(
    params, hparams, upper_loss, lower_loss, K, fp_map, tol=1e-10, stochastic=False
):
    """
    Hyperparameter gradients via the Conjugate Gradient method.

    ``cg_step`` is run on the operator ``v -> v - J^T v``, where ``J^T v`` is
    the vector-Jacobian product of the fixed-point map w.r.t. ``params``, with
    the outer gradient w.r.t. ``params`` as right-hand side. The solution is
    then pushed through the fixed-point map's dependence on ``hparams`` and
    added to the direct outer gradient w.r.t. ``hparams``.

    Parameters
    ----------
    params : List[jittor.Var]
        List of parameters for the lower-level optimization problem.
    hparams : List[jittor.Var]
        List of hyperparameters for the upper-level optimization problem.
    upper_loss : jittor.Var
        Loss function for the upper-level problem.
    lower_loss : jittor.Var
        Loss function for the lower-level problem.
    K : int
        Maximum number of iterations for the Conjugate Gradient method.
    fp_map : Callable
        Fixed-point map, called as ``fp_map(params, lower_loss)``.
    tol : float, optional
        Tolerance for early stopping based on convergence, by default 1e-10.
    stochastic : bool, optional
        If True, the fixed-point map is re-evaluated inside every operator
        application and once more after the solve; otherwise it is evaluated
        once up front and reused. By default False.

    Returns
    -------
    List[jittor.Var]
        Hyperparameter gradients computed using the Conjugate Gradient method.
    """
    grad_outer_w, grad_outer_hparams = get_outer_gradients(upper_loss, params, hparams)

    if stochastic:

        def dfp_map_dw(xs):
            # Fresh evaluation of the map for every operator application.
            fresh_map = fp_map(params, lower_loss)
            jac_t_v = custom_grad(
                fresh_map, params, grad_outputs=xs, retain_graph=False
            )
            return [v - j for v, j in zip(xs, jac_t_v)]

    else:
        w_mapped = fp_map(params, lower_loss)

        def dfp_map_dw(xs):
            # Reuse the single evaluation; the graph must survive repeated calls.
            jac_t_v = custom_grad(w_mapped, params, grad_outputs=xs, retain_graph=True)
            return [v - j for v, j in zip(xs, jac_t_v)]

    vs = cg_step(dfp_map_dw, grad_outer_w, max_iter=K, epsilon=tol)

    if stochastic:
        w_mapped = fp_map(params, lower_loss)

    indirect = custom_grad(w_mapped, hparams, grad_outputs=vs)
    return [g + v for g, v in zip(indirect, grad_outer_hparams)]
def cg_step(Ax, b, max_iter=100, epsilon=1.0e-5):
    """
    Perform Conjugate Gradient (CG) iterations to solve Ax = b.

    The iteration starts from x = 0 and stops once the residual norm is below
    ``epsilon`` or after ``max_iter`` operator applications. The most recent
    iterate is always the one returned, including when the stopping criterion
    is met.

    Parameters
    ----------
    Ax : Callable
        Function that computes the matrix-vector product Ax for a given x
        (a list of tensors shaped like ``b``).
    b : List[jittor.Var]
        Right-hand side of the equation Ax = b.
    max_iter : int, optional
        Maximum number of iterations for the CG method, by default 100.
    epsilon : float, optional
        Convergence threshold for the residual norm, by default 1e-5.

    Returns
    -------
    List[jittor.Var]
        Solution vector x that approximately solves Ax = b.
    """
    x = [jit.zeros_like(bb) for bb in b]
    r = [jit.zeros_like(bb) + bb for bb in b]  # residual b - A x, with x = 0
    p = [jit.zeros_like(rr) + rr for rr in r]  # search direction

    for _ in range(max_iter):
        r_vec = cat_list_to_tensor(r)
        # Test convergence before the residual is used as a divisor: this
        # returns the converged iterate (not the one before it) and avoids a
        # 0/0 step size when the residual is already zero (e.g. b == 0).
        if float(jit.norm(r_vec)) < epsilon:
            break

        rTr = jit.sum(r_vec * r_vec)
        Ap = Ax(p)
        pAp = jit.sum(cat_list_to_tensor(p) * cat_list_to_tensor(Ap))
        alpha = rTr / pAp

        x = [xx + alpha * pp for xx, pp in zip(x, p)]
        r_next = [rr - alpha * ap for rr, ap in zip(r, Ap)]

        r_next_vec = cat_list_to_tensor(r_next)
        beta = jit.sum(r_next_vec * r_next_vec) / rTr
        p = [rr + beta * pp for rr, pp in zip(r_next, p)]
        r = r_next

    return x