jboat.na_ol
Submodules
jboat.na_ol.hyper_gradient
- class jboat.na_ol.hyper_gradient.HyperGradient(ll_objective, ul_objective, ul_model, ll_model, ll_var, ul_var, solver_config)[source]
Bases: object
Base class for computing hyper-gradients of upper-level variables in bilevel optimization problems.
This class provides an abstract interface for hyper-gradient computation that can be extended for specific methods such as Conjugate Gradient, Finite Differentiation, or First-Order Approximation.
- Parameters:
ll_objective (callable) – The lower-level objective function of the bilevel optimization problem.
ul_objective (callable) – The upper-level objective function of the bilevel optimization problem.
ul_model (jittor.Module) – The upper-level model of the bilevel optimization problem.
ll_model (jittor.Module) – The lower-level model of the bilevel optimization problem.
ll_var (List[jittor.Var]) – A list of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – A list of variables optimized with the upper-level objective.
solver_config (dict) – Dictionary containing configurations for the solver.
jboat.na_ol.sequential_hg
- class jboat.na_ol.sequential_hg.SequentialHG(ordered_instances, custom_order)[source]
Bases: object
A class for managing sequential hyper-gradient operations.
This class dynamically organizes and executes a sequence of hyper-gradient computations using user-defined and validated orders of gradient operators.
- Parameters:
ordered_instances (List[object]) – A list of instantiated gradient operator objects, ordered as per the adjusted sequence.
custom_order (List[str]) – The user-defined order of gradient operators.
- compute_gradients(**kwargs)[source]
Compute hyper-gradients sequentially using the ordered instances.
This method processes the hyper-gradients in the defined order, passing intermediate results between consecutive gradient operators.
- Parameters:
**kwargs (dict) – Additional arguments required for gradient computations.
- Returns:
A list of dictionaries containing results for each gradient operator.
- Return type:
List[Dict]
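The chaining described above, where each operator receives the output of the previous one, can be sketched in plain Python. The classes `TruncateOp` and `FinalOp` and the helper `run_sequence` are hypothetical stand-ins introduced for illustration only; they are not part of jboat, whose real operators work on feed dictionaries and an auxiliary model.

```python
from typing import Any, Dict, List


class TruncateOp:
    """Hypothetical intermediate operator: adjusts arguments and passes them on."""

    def compute_gradients(self, **kwargs: Any) -> Dict[str, Any]:
        out = dict(kwargs)
        # Pretend to shorten the trajectory before the next operator runs.
        out["max_loss_iter"] = min(kwargs.get("max_loss_iter", 0), 2)
        return out


class FinalOp:
    """Hypothetical terminal operator: reports a loss and a finished flag."""

    def compute_gradients(self, **kwargs: Any) -> Dict[str, Any]:
        return {"upper_loss": 0.5 * kwargs["max_loss_iter"],
                "hyper_gradient_finished": True}


def run_sequence(instances: List[Any], **kwargs: Any) -> List[Dict[str, Any]]:
    """Call each operator in order, feeding each result to the next one."""
    results = []
    current = kwargs
    for op in instances:
        current = op.compute_gradients(**current)
        results.append(current)
    return results


results = run_sequence([TruncateOp(), FinalOp()], max_loss_iter=5)
print(results[-1])
```

The returned list holds one dictionary per operator, mirroring the documented return type of List[Dict].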
- jboat.na_ol.sequential_hg.makes_functional_na_operation(custom_order, **kwargs)[source]
Dynamically create a SequentialHG object with ordered gradient operators.
This function validates the user-defined operator order, adjusts it to conform with predefined gradient rules, and dynamically loads the corresponding operator classes.
- Parameters:
custom_order (List[str]) – The user-defined order of gradient operators.
**kwargs (dict) – Additional arguments required for initializing gradient operator instances.
- Returns:
An instance of SequentialHG containing the ordered gradient operators and result management.
- Return type:
SequentialHG
- jboat.na_ol.sequential_hg.validate_and_adjust_order(custom_order, gradient_order)[source]
Validate and adjust the custom order to match the predefined gradient order.
This function ensures that the user-defined order adheres to the predefined grouping rules and adjusts it accordingly.
- Parameters:
custom_order (List[str]) – The user-provided order of gradient operators.
gradient_order (List[List[str]]) – The predefined order of gradient operator groups.
- Returns:
Adjusted order of gradient operators following the predefined rules.
- Return type:
List[str]
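The grouping rules themselves are not spelled out in this reference. As a rough sketch, assuming the rule is simply "operators must appear in the order of the groups they belong to", the adjustment could look like the following. The function name `adjust_order` and the example `groups` value are hypothetical and chosen only for illustration.

```python
from typing import List


def adjust_order(custom_order: List[str],
                 gradient_order: List[List[str]]) -> List[str]:
    """Sort operators by the index of their group (assumed rule, not jboat's)."""
    group_of = {name: idx
                for idx, group in enumerate(gradient_order)
                for name in group}
    unknown = [name for name in custom_order if name not in group_of]
    if unknown:
        raise ValueError(f"Unknown gradient operators: {unknown}")
    # sorted() is stable, so operators within one group keep the user's order.
    return sorted(custom_order, key=lambda name: group_of[name])


# Hypothetical grouping, for illustration only.
groups = [["PTT", "FOA", "RGT"], ["IAD", "RAD", "FD", "IGA"], ["CG", "NS"]]
adjusted = adjust_order(["RAD", "PTT"], groups)
print(adjusted)
```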
jboat.na_ol.cg
- class jboat.na_ol.cg.CG(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Conjugate Gradient (CG) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) –
Dictionary containing solver configurations. Expected keys include:
r (float): Perturbation radius for finite differences.
lower_level_opt (jittor.optim.Optimizer): Lower-level optimizer configuration.
gm_op (str): Indicates dynamic initialization type (e.g., “DI”).
- GDA-specific parameters if applicable, such as:
alpha_init (float): Initial learning rate for GDA.
alpha_decay (float): Decay factor for GDA.
- ll_lr
Learning rate for the lower-level optimizer, extracted from lower_level_opt.
- Type:
float
- dynamic_initialization
Indicates whether dynamic initialization is enabled (based on gm_op).
- Type:
bool
- tolerance
The tolerance for approximation.
- Type:
float
- K
Number of iterations for CG approximation.
- Type:
int
- alpha
Initial learning rate for GDA operations.
- Type:
float
- alpha_decay
Decay factor applied to the learning rate for GDA.
- Type:
float
- gda_loss
Custom loss function for GDA operations, if specified in solver_config.
- Type:
Callable, optional
References
[1] Pedregosa F. “Hyperparameter optimization with approximate gradient,” in ICML, 2016.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.
hyper_gradient_finished (bool, optional) – A flag indicating whether the hyper-gradient computation is finished. Default is False.
next_operation (str, optional) – The next operator for the calculation of the hypergradient. Default is None.
**kwargs (dict) –
Additional arguments, such as:
lower_model_params (list): Parameters of the lower-level model (default: list(auxiliary_model.parameters())).
hparams (list): Hyper-parameters of the upper-level model (default: list(self.ul_var)).
- Returns:
A dictionary containing:
“upper_loss”: The current upper-level objective value.
“hyper_gradient_finished”: A boolean indicating that the hyper-gradient computation is complete.
- Return type:
dict
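Implicit-gradient methods of this kind need the solution of a linear system involving the lower-level Hessian, and conjugate gradient solves it using only Hessian-vector products. The following is a generic CG solver in plain Python, not jboat's implementation; the arguments `K` and `tolerance` are named after the attributes listed above, while `hvp`, `dot`, and the toy matrix `H` are assumptions made for the example.

```python
from typing import Callable, List

Vector = List[float]


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def conjugate_gradient(hvp: Callable[[Vector], Vector], b: Vector,
                       K: int = 10, tolerance: float = 1e-10) -> Vector:
    """Approximately solve H v = b using only Hessian-vector products."""
    v = [0.0] * len(b)
    r = list(b)           # residual b - H v, with v = 0
    p = list(r)           # current search direction
    rs_old = dot(r, r)
    for _ in range(K):
        if rs_old < tolerance:
            break
        Hp = hvp(p)
        step = rs_old / dot(p, Hp)
        v = [vi + step * pi for vi, pi in zip(v, p)]
        r = [ri - step * hi for ri, hi in zip(r, Hp)]
        rs_new = dot(r, r)
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return v


# Toy symmetric positive-definite Hessian and right-hand side.
H = [[4.0, 1.0], [1.0, 3.0]]


def hvp(p: Vector) -> Vector:
    return [dot(row, p) for row in H]


v = conjugate_gradient(hvp, [1.0, 2.0])
print(v)  # close to [1/11, 7/11]
```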
jboat.na_ol.fd
- class jboat.na_ol.fd.FD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Finite Differentiation (FD) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) –
Dictionary containing solver configurations. Expected keys include:
r (float): Perturbation radius for finite differences.
lower_level_opt (jittor.optim.Optimizer): Lower-level optimizer configuration.
gm_op (str): Indicates dynamic initialization type (e.g., “DI”).
- GDA-specific parameters if applicable, such as:
alpha_init (float): Initial learning rate for GDA.
alpha_decay (float): Decay factor for GDA.
- ll_lr
Learning rate for the lower-level optimizer, extracted from lower_level_opt.
- Type:
float
- dynamic_initialization
Indicates whether dynamic initialization is enabled (based on gm_op).
- Type:
bool
- _r
Perturbation radius for finite differences, used for gradient computation.
- Type:
float
- alpha
Initial learning rate for GDA operations.
- Type:
float
- alpha_decay
Decay factor applied to the learning rate for GDA.
- Type:
float
- gda_loss
Custom loss function for GDA operations, if specified in solver_config.
- Type:
Callable, optional
References
[1] H. Liu, K. Simonyan, Y. Yang, “DARTS: Differentiable Architecture Search,” in ICLR, 2019.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.
hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hyper-gradient computation is finished. Default is False.
next_operation (str, optional) – The next operator for the calculation of the hyper-gradient. Default is None.
- Returns:
A dictionary containing:
“upper_loss”: The current upper-level objective value.
“hyper_gradient_finished”: A boolean indicating whether the hyper-gradient computation is complete.
- Return type:
dict
- Raises:
AssertionError – If next_operation is not None, as FD does not support next_operation.
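The role of the perturbation radius `r` can be illustrated with a central difference, the approximation used in the cited DARTS paper to avoid forming second derivatives explicitly. The sketch below is a scalar toy example, not jboat's code: the objective f(x, y) = x * y**2 and the helper names `grad_x` and `fd_mixed_product` are assumptions made for illustration.

```python
def grad_x(x: float, y: float) -> float:
    """Gradient of the toy lower-level objective f(x, y) = x * y**2 w.r.t. x."""
    return y ** 2


def fd_mixed_product(x: float, y: float, v: float, r: float = 1e-2) -> float:
    """Central-difference estimate of (d^2 f / dx dy) * v."""
    y_plus = y + r * v    # lower-level variable perturbed along +v
    y_minus = y - r * v   # lower-level variable perturbed along -v
    return (grad_x(x, y_plus) - grad_x(x, y_minus)) / (2.0 * r)


estimate = fd_mixed_product(x=1.0, y=3.0, v=0.5)
exact = 2.0 * 3.0 * 0.5   # analytic mixed derivative 2*y, times v
print(estimate, exact)
```

Only two extra gradient evaluations are needed, which is why this approach is attractive when the lower-level model is large.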
jboat.na_ol.foa
- class jboat.na_ol.foa.FOA(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using First-Order Approximation (FOA) [1], leveraging Initialization-based Auto Differentiation (IAD) [2].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) – Dictionary containing solver configurations.
References
[1] Nichol A., “On first-order meta-learning algorithms,” arXiv preprint arXiv:1803.02999, 2018.
[2] Finn C., Abbeel P., Levine S., “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables using the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower-level model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.
hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished, by default False.
next_operation (str, optional) – The next operator for the calculation of the hypergradient, by default None.
**kwargs (dict) – Additional keyword arguments.
- Returns:
A dictionary containing information required for the next step in the hypergradient computation, including the feed dictionaries, auxiliary model, iteration count, and other optional arguments.
- Return type:
Dict
- Raises:
AssertionError – If next_operation is not defined or if hyper_gradient_finished is True.
jboat.na_ol.iad
- class jboat.na_ol.iad.IAD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Initialization-based Auto Differentiation (IAD) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) – Dictionary containing solver configurations.
References
[1] Finn C., Abbeel P., Levine S., “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int) – The number of iterations used for backpropagation.
next_operation (str) – The next operator for the calculation of the hypergradient.
hyper_gradient_finished (bool) – A boolean flag indicating whether the hypergradient computation is finished.
- Returns:
A dictionary containing the upper-level objective and the status of hypergradient computation.
- Return type:
Dict
jboat.na_ol.iga
- class jboat.na_ol.iga.IGA(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Implicit Gradient Approximation (IGA) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) –
Dictionary containing solver configurations, including:
alpha_init (float): Initial learning rate for GDA.
alpha_decay (float): Decay factor for the GDA learning rate.
Optional gda_loss (Callable): Custom loss function for GDA, if applicable.
gm_op (List[str]): Specifies dynamic operations, e.g., “DI” for dynamic initialization.
- alpha
Initial learning rate for GDA operations, if applicable.
- Type:
float
- alpha_decay
Decay factor applied to the GDA learning rate.
- Type:
float
- gda_loss
Custom loss function for GDA operations, if specified in solver_config.
- Type:
Callable, optional
- dynamic_initialization
Indicates whether dynamic initialization is enabled, based on gm_op.
- Type:
bool
References
[1] Liu R., Gao J., Liu X., et al., “Learning with constraint learning: New perspective, solution strategy and various applications,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2024.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables using the given feed dictionaries and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization, including training data, targets, and other information required for the LL objective computation.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization, including validation data, targets, and other information required for the UL objective computation.
auxiliary_model (_MonkeyPatchBase) – A patched lower-level model wrapped by the higher library, enabling differentiable optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.
hyper_gradient_finished (bool, optional) – A flag indicating whether the hypergradient computation is finished, by default False.
next_operation (str, optional) – The next operator for hypergradient calculation. Not supported in this implementation, by default None.
**kwargs (dict) –
Additional arguments, such as:
lower_model_params (List[jt.Var]): List of parameters for the lower-level model.
- Returns:
A dictionary containing:
upper_loss (jt.Var): The upper-level objective value after optimization.
hyper_gradient_finished (bool): Indicates whether the hypergradient computation is complete.
- Return type:
Dict
Notes
This implementation calculates the Gauss-Newton (GN) loss to refine the gradients using second-order approximations.
If dynamic_initialization is enabled, the gradients of the lower-level variables are updated with time-dependent parameters.
Updates are performed on both lower-level and upper-level variables using computed gradients.
jboat.na_ol.ns
- class jboat.na_ol.ns.NS(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Neumann Series (NS) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) –
Dictionary containing solver configurations, including:
gm_op (str): Indicates dynamic initialization type (e.g., “DI”).
lower_level_opt (Optimizer): Lower-level optimizer configuration.
- CG (Dict): Conjugate Gradient-specific parameters:
tolerance (float): Tolerance for convergence.
k (int): Number of iterations for Neumann approximation.
GDA-specific parameters, such as alpha_init and alpha_decay.
gda_loss (Callable, optional): Custom loss function for GDA.
References
[1] J. Lorraine, P. Vicol, and D. Duvenaud, “Optimizing millions of hyperparameters by implicit differentiation,” in AISTATS, 2020.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation. Defaults to 0.
next_operation (str, optional) – The next operator for the calculation of the hypergradient. Defaults to None.
hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished. Defaults to False.
- Returns:
A dictionary containing the upper-level objective and the status of hypergradient computation.
- Return type:
Dict
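The Neumann-series idea from the cited paper replaces the inverse Hessian-vector product with a truncated sum, H^{-1} b ≈ alpha * sum_{k=0}^{K} (I - alpha*H)^k b, which converges when the spectral radius of I - alpha*H is below 1. The sketch below is a generic pure-Python version under that assumption, not jboat's implementation; `neumann_inverse_hvp`, `hvp`, and the toy matrix `H` are hypothetical names introduced for the example.

```python
from typing import Callable, List

Vector = List[float]


def neumann_inverse_hvp(hvp: Callable[[Vector], Vector], b: Vector,
                        alpha: float = 0.2, K: int = 60) -> Vector:
    """Approximate H^{-1} b by alpha * sum_{k=0}^{K} (I - alpha*H)^k b."""
    term = list(b)     # holds (I - alpha*H)^k b, starting at k = 0
    total = list(b)    # running sum of the terms
    for _ in range(K):
        Ht = hvp(term)
        term = [t - alpha * h for t, h in zip(term, Ht)]
        total = [s + t for s, t in zip(total, term)]
    return [alpha * s for s in total]


# Toy symmetric positive-definite Hessian; alpha = 0.2 keeps the series convergent.
H = [[4.0, 1.0], [1.0, 3.0]]


def hvp(p: Vector) -> Vector:
    return [sum(h * x for h, x in zip(row, p)) for row in H]


approx = neumann_inverse_hvp(hvp, [1.0, 2.0])
print(approx)  # close to [1/11, 7/11]
```

A small number of terms gives a cheaper but biased estimate, which is the trade-off controlled by the iteration count k in solver_config.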
jboat.na_ol.ptt
- class jboat.na_ol.ptt.PTT(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Pessimistic Trajectory Truncation (PTT) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) –
Dictionary containing solver configurations, including:
“na_op” (List[str]): Indicates if PTT is used in the hyper-gradient operations.
References
[1] Liu R., Liu Y., Zeng S., et al. “Towards gradient-based bilevel optimization with non-convex followers and beyond,” in NeurIPS, 2021.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.
next_operation (str, optional) – The next operator for the calculation of the hypergradient, by default None.
hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished, by default False.
- Returns:
A dictionary containing updated feed_dict, auxiliary model, and gradient computation results.
- Return type:
Dict
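As a rough illustration of the truncation step, the sketch below assumes the truncation point is the lower-level iterate with the largest upper-level objective, which is how the cited paper describes pessimistic truncation. Whether jboat follows exactly this rule is not stated in this reference, and the function name and the loss values are hypothetical.

```python
from typing import List


def pessimistic_truncation_index(upper_losses: List[float]) -> int:
    """Return the index of the iterate with the largest upper-level loss."""
    return max(range(len(upper_losses)), key=lambda k: upper_losses[k])


# Hypothetical upper-level losses recorded along a lower-level trajectory.
losses = [0.9, 1.4, 1.1, 0.7]
k_star = pessimistic_truncation_index(losses)
print(k_star)
```

Back-propagation would then run only through the trajectory up to the selected iterate.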
jboat.na_ol.rad
- class jboat.na_ol.rad.RAD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Reverse Auto Differentiation (RAD) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) – Dictionary containing solver configurations, including optional dynamic operation settings.
References
[1] Franceschi L., et al. “Forward and reverse gradient-based hyperparameter optimization,” in ICML, 2017.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables using the provided data and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. Typically includes training data, targets, and other information required to compute the lower-level objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. Typically includes validation data, targets, and other information required to compute the upper-level objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.
next_operation (str, optional) – The next operator for the calculation of the hypergradient. Default is None.
**kwargs (dict) – Additional keyword arguments passed to the method.
- Returns:
A dictionary containing the upper-level objective and the status of hypergradient computation.
- Return type:
Dict
jboat.na_ol.rgt
- class jboat.na_ol.rgt.RGT(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]
Bases: HyperGradient
Computes the hyper-gradient of the upper-level variables using Reverse Gradient Truncation (RGT) [1].
- Parameters:
ll_objective (Callable) – The lower-level objective function of the BLO problem.
ul_objective (Callable) – The upper-level objective function of the BLO problem.
ll_model (jittor.Module) – The lower-level model of the BLO problem.
ul_model (jittor.Module) – The upper-level model of the BLO problem.
ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.
ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.
solver_config (Dict[str, Any]) – Dictionary containing solver configurations, including the hyper-gradient operations and truncation settings.
References
[1] Shaban A., Cheng C.A., Hatch N., et al. “Truncated back-propagation for bilevel optimization,” in AISTATS, 2019.
- compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]
Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.
- Parameters:
ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.
ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.
auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.
max_loss_iter (int) – The number of iterations used for backpropagation.
next_operation (str) – The next operator for the calculation of the hypergradient.
hyper_gradient_finished (bool) – A boolean flag indicating whether the hypergradient computation is finished.
- Returns:
The current upper-level objective.
- Return type:
Any
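Truncated back-propagation can be illustrated on a scalar toy problem in which the lower level runs gradient descent on f(x, y) = 0.5 * (y - x)**2 and the upper level evaluates F(y) = 0.5 * (y - 1)**2. The sketch below back-propagates through only the last `last_k` steps; with `last_k` equal to `T` it reduces to full reverse-mode differentiation through the unrolled trajectory. All names and the toy objectives are assumptions made for illustration and do not reflect jboat's internals.

```python
from typing import Tuple


def truncated_hypergradient(x: float, y0: float, lr: float = 0.1,
                            T: int = 20, last_k: int = 5) -> Tuple[float, float]:
    """Return (dF/dx estimate, y_T) using only the last `last_k` steps."""
    # Forward pass: T gradient-descent steps on f(x, y) = 0.5 * (y - x)**2.
    y = y0
    for _ in range(T):
        y = y - lr * (y - x)
    # Backward pass. For this quadratic the step Jacobians are constant,
    # so the trajectory does not need to be stored.
    adjoint = y - 1.0            # dF/dy_T for F(y) = 0.5 * (y - 1)**2
    grad_x = 0.0
    for _ in range(min(last_k, T)):
        grad_x += adjoint * lr   # dy_{t+1}/dx = lr
        adjoint *= (1.0 - lr)    # dy_{t+1}/dy_t = 1 - lr
    return grad_x, y


full_grad, y_T = truncated_hypergradient(x=2.0, y0=0.0, last_k=20)
short_grad, _ = truncated_hypergradient(x=2.0, y0=0.0, last_k=5)
print(full_grad, short_grad)
```

The truncated estimate is smaller in magnitude here because the contributions of the earliest steps are dropped; the gap shrinks as the lower-level dynamics contract faster.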