jboat.na_ol

Submodules

jboat.na_ol.hyper_gradient

class jboat.na_ol.hyper_gradient.HyperGradient(ll_objective, ul_objective, ul_model, ll_model, ll_var, ul_var, solver_config)[source]

Bases: object

Base class for computing hyper-gradients of upper-level variables in bilevel optimization problems.

This class provides an abstract interface for hyper-gradient computation that can be extended for specific methods such as Conjugate Gradient, Finite Differentiation, or First-Order Approximation.

Parameters:
  • ll_objective (callable) – The lower-level objective function of the bilevel optimization problem.

  • ul_objective (callable) – The upper-level objective function of the bilevel optimization problem.

  • ul_model (jittor.Module) – The upper-level model of the bilevel optimization problem.

  • ll_model (jittor.Module) – The lower-level model of the bilevel optimization problem.

  • ll_var (List[jittor.Var]) – A list of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – A list of variables optimized with the upper-level objective.

  • solver_config (dict) – Dictionary containing configurations for the solver.

abstract compute_gradients(**kwargs)[source]
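The abstract interface described above can be sketched in plain Python. This is a minimal illustration, not the jboat source: the class names `HyperGradientSketch` and `ConstantHG`, the reduced constructor arguments, and the scalar toy objectives are all assumptions made for the example.

```python
from abc import ABC, abstractmethod


class HyperGradientSketch(ABC):
    """Hypothetical stand-in for the base class (not the jboat implementation)."""

    def __init__(self, ll_objective, ul_objective, ll_var, ul_var, solver_config):
        self.ll_objective = ll_objective
        self.ul_objective = ul_objective
        self.ll_var = ll_var
        self.ul_var = ul_var
        self.solver_config = solver_config

    @abstractmethod
    def compute_gradients(self, **kwargs):
        """Subclasses return a dict describing the hyper-gradient step."""


class ConstantHG(HyperGradientSketch):
    """Toy subclass: evaluates the upper-level loss and marks the step as done."""

    def compute_gradients(self, **kwargs):
        loss = self.ul_objective(self.ll_var, self.ul_var)
        return {"upper_loss": loss, "hyper_gradient_finished": True}


op = ConstantHG(
    ll_objective=lambda y, x: (y - x) ** 2,    # toy lower-level objective
    ul_objective=lambda y, x: y ** 2 + x ** 2,  # toy upper-level objective
    ll_var=1.0,
    ul_var=2.0,
    solver_config={},
)
result = op.compute_gradients()
```

Instantiating the base class directly raises `TypeError`, which is what forces each concrete method to supply its own `compute_gradients`.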

jboat.na_ol.sequential_hg

class jboat.na_ol.sequential_hg.SequentialHG(ordered_instances, custom_order)[source]

Bases: object

A class for managing sequential hyper-gradient operations.

This class dynamically organizes and executes a sequence of hyper-gradient computations using user-defined and validated orders of gradient operators.

Parameters:
  • ordered_instances (List[object]) – A list of instantiated gradient operator objects, ordered as per the adjusted sequence.

  • custom_order (List[str]) – The user-defined order of gradient operators.

compute_gradients(**kwargs)[source]

Compute hyper-gradients sequentially using the ordered instances.

This method processes the hyper-gradients in the defined order, passing intermediate results between consecutive gradient operators.

Parameters:

**kwargs (dict) – Additional arguments required for gradient computations.

Returns:

A list of dictionaries containing results for each gradient operator.

Return type:

List[Dict]
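The chaining described above might look like the following sketch. It assumes that each operator's returned dictionary is merged into the keyword arguments of the next operator; this reference does not state how jboat forwards intermediate results, so the `state.update(out)` step and the toy classes `Truncate` and `Finish` are hypothetical.

```python
class SequentialSketch:
    """Hypothetical chain of gradient operators (not the jboat implementation)."""

    def __init__(self, ordered_instances):
        self.ordered_instances = ordered_instances

    def compute_gradients(self, **kwargs):
        results = []
        state = dict(kwargs)
        for op in self.ordered_instances:
            out = op.compute_gradients(**state)
            results.append(out)
            state.update(out)  # assumption: outputs feed the next operator
        return results


class Truncate:
    """Toy first stage: decides how many iterations to back-propagate."""

    def compute_gradients(self, **kwargs):
        return {"max_loss_iter": 3, "hyper_gradient_finished": False}


class Finish:
    """Toy last stage: consumes max_loss_iter and reports completion."""

    def compute_gradients(self, max_loss_iter=0, **kwargs):
        return {"upper_loss": float(max_loss_iter), "hyper_gradient_finished": True}


results = SequentialSketch([Truncate(), Finish()]).compute_gradients()
```

The returned list holds one dictionary per operator, in execution order.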

jboat.na_ol.sequential_hg.makes_functional_na_operation(custom_order, **kwargs)[source]

Dynamically create a SequentialHG object with ordered gradient operators.

This function validates the user-defined operator order, adjusts it to conform with predefined gradient rules, and dynamically loads the corresponding operator classes.

Parameters:
  • custom_order (List[str]) – The user-defined order of gradient operators.

  • **kwargs (dict) – Additional arguments required for initializing gradient operator instances.

Returns:

An instance of SequentialHG containing the ordered gradient operators and result management.

Return type:

SequentialHG

jboat.na_ol.sequential_hg.validate_and_adjust_order(custom_order, gradient_order)[source]

Validate and adjust the custom order to match the predefined gradient order.

This function ensures that the user-defined order adheres to the predefined grouping rules and adjusts it accordingly.

Parameters:
  • custom_order (List[str]) – The user-provided order of gradient operators.

  • gradient_order (List[List[str]]) – The predefined order of gradient operator groups.

Returns:

Adjusted order of gradient operators following the predefined rules.

Return type:

List[str]
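One plausible reading of this adjustment is a stable sort of the user's operators by the index of the predefined group they belong to. The sketch below follows that reading; the function name `adjust_order_sketch`, the error raised for unknown operators, and the group contents are assumptions for illustration and may differ from what jboat does.

```python
def adjust_order_sketch(custom_order, gradient_order):
    """Reorder custom_order so operators follow the group order (hypothetical)."""
    # Map each operator name to the index of its group.
    group_of = {op: i for i, group in enumerate(gradient_order) for op in group}
    unknown = [op for op in custom_order if op not in group_of]
    if unknown:
        raise ValueError(f"unknown operators: {unknown}")
    # sorted() is stable, so operators within one group keep the user's order.
    return sorted(custom_order, key=lambda op: group_of[op])


# Group contents are made up for this example.
groups = [["PTT", "RGT"], ["RAD", "IAD"], ["CG", "NS"]]
adjusted = adjust_order_sketch(["CG", "RAD", "PTT"], groups)
```

Here the user-supplied order `["CG", "RAD", "PTT"]` is rearranged so that the operator from the first group runs first.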

jboat.na_ol.cg

class jboat.na_ol.cg.CG(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Conjugate Gradient (CG) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) –

    Dictionary containing solver configurations. Expected keys include:

    • r (float): Perturbation radius for finite differences.

    • lower_level_opt (jittor.optim.Optimizer): Lower-level optimizer configuration.

    • gm_op (str): Indicates dynamic initialization type (e.g., “DI”).

    • GDA-specific parameters if applicable, such as:
      • alpha_init (float): Initial learning rate for GDA.

      • alpha_decay (float): Decay factor for GDA.

ll_lr

Learning rate for the lower-level optimizer, extracted from lower_level_opt.

Type:

float

dynamic_initialization

Indicates whether dynamic initialization is enabled (based on gm_op).

Type:

bool

tolerance

The tolerance for approximation.

Type:

float

K

Number of iterations for CG approximation.

Type:

int

alpha

Initial learning rate for GDA operations.

Type:

float

alpha_decay

Decay factor applied to the learning rate for GDA.

Type:

float

gda_loss

Custom loss function for GDA operations, if specified in solver_config.

Type:

Callable, optional

References

[1] Pedregosa F. “Hyperparameter optimization with approximate gradient,” in ICML, 2016.
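As background, implicit-differentiation methods need the product of an inverse Hessian with a gradient vector, and conjugate gradient obtains it by solving the linear system H v = g using only Hessian-vector products. The following pure-Python sketch shows that solver on a small symmetric positive-definite matrix. The function name `conjugate_gradient`, the `matvec` callback, and the 2×2 example are assumptions for illustration; `K` and `tolerance` only echo the attribute names above and are not tied to jboat's internals.

```python
def conjugate_gradient(matvec, b, K=50, tolerance=1e-10):
    """Approximately solve H v = b for symmetric positive-definite H."""
    x = [0.0] * len(b)
    r = list(b)          # residual b - H x, with x = 0 initially
    p = list(r)          # search direction
    rs = sum(ri * ri for ri in r)
    for _ in range(K):
        if rs < tolerance:   # stop once the squared residual norm is small
            break
        Ap = matvec(p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


H = [[4.0, 1.0], [1.0, 3.0]]  # toy lower-level Hessian


def hvp(v):
    """Hessian-vector product for the toy matrix H."""
    return [sum(hij * vj for hij, vj in zip(row, v)) for row in H]


v = conjugate_gradient(hvp, [1.0, 2.0])
```

For this system the exact solution is (1/11, 7/11); in a real bilevel problem `hvp` would be computed by automatic differentiation rather than from an explicit matrix.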

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.

  • hyper_gradient_finished (bool, optional) – A flag indicating whether the hyper-gradient computation is finished. Default is False.

  • next_operation (str, optional) – The next operator for the calculation of the hypergradient. Default is None.

  • **kwargs (dict) –

    Additional arguments, such as:

    • lower_model_params (list): Parameters of the lower-level model (default: list(auxiliary_model.parameters())).

    • hparams (list): Hyper-parameters of the upper-level model (default: list(self.ul_var)).

Returns:

A dictionary containing:

  • “upper_loss”: The current upper-level objective value.

  • “hyper_gradient_finished”: A boolean indicating that the hyper-gradient computation is complete.

Return type:

dict

jboat.na_ol.fd

class jboat.na_ol.fd.FD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Finite Differentiation (FD) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) –

    Dictionary containing solver configurations. Expected keys include:

    • r (float): Perturbation radius for finite differences.

    • lower_level_opt (jittor.optim.Optimizer): Lower-level optimizer configuration.

    • gm_op (str): Indicates dynamic initialization type (e.g., “DI”).

    • GDA-specific parameters if applicable, such as:
      • alpha_init (float): Initial learning rate for GDA.

      • alpha_decay (float): Decay factor for GDA.

ll_lr

Learning rate for the lower-level optimizer, extracted from lower_level_opt.

Type:

float

dynamic_initialization

Indicates whether dynamic initialization is enabled (based on gm_op).

Type:

bool

_r

Perturbation radius for finite differences, used for gradient computation.

Type:

float

alpha

Initial learning rate for GDA operations.

Type:

float

alpha_decay

Decay factor applied to the learning rate for GDA.

Type:

float

gda_loss

Custom loss function for GDA operations, if specified in solver_config.

Type:

Callable, optional

References

[1] H. Liu, K. Simonyan, Y. Yang, “DARTS: Differentiable Architecture Search,” in ICLR, 2019.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.

  • hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hyper-gradient computation is finished. Default is False.

  • next_operation (str, optional) – The next operator for the calculation of the hyper-gradient. Default is None.

Returns:

A dictionary containing:

  • “upper_loss”: The current upper-level objective value.

  • “hyper_gradient_finished”: A boolean indicating whether the hyper-gradient computation is complete.

Return type:

dict

Raises:

AssertionError – If next_operation is not None, as FD does not support next_operation.
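The finite-difference idea from the cited DARTS paper replaces a second-order mixed derivative times a vector with a central difference of first-order gradients, evaluated at lower-level variables perturbed by ±r along that vector. The sketch below demonstrates this on a scalar toy objective f(x, y) = x·y². The names `grad_x_ll` and `fd_mixed_product` are hypothetical, and the way jboat scales `r` is not specified in this reference.

```python
def grad_x_ll(x, y):
    """Gradient w.r.t. x of the toy lower-level objective f(x, y) = x * y**2."""
    return y ** 2


def fd_mixed_product(grad_x, x, y, v, r=1e-2):
    """Central-difference estimate of (d^2 f / dx dy) * v using perturbation radius r."""
    return (grad_x(x, y + r * v) - grad_x(x, y - r * v)) / (2.0 * r)


approx = fd_mixed_product(grad_x_ll, x=1.5, y=2.0, v=0.5)
exact = 2.0 * 2.0 * 0.5  # analytic value: d^2 f / dx dy = 2 * y, times v
```

Because the toy gradient is quadratic in y, the central difference is exact here up to rounding; for general objectives the error shrinks as r decreases, until floating-point cancellation dominates.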

jboat.na_ol.foa

class jboat.na_ol.foa.FOA(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using First-Order Approximation (FOA) [1], leveraging Initialization-based Auto Differentiation (IAD) [2].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) – Dictionary containing solver configurations.

References

[1] Nichol A., “On first-order meta-learning algorithms,” arXiv preprint arXiv:1803.02999, 2018.

[2] Finn C., Abbeel P., Levine S., “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables using the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower-level model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.

  • hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished, by default False.

  • next_operation (str, optional) – The next operator for the calculation of the hypergradient, by default None.

  • **kwargs (dict) – Additional keyword arguments.

Returns:

A dictionary containing information required for the next step in the hypergradient computation, including the feed dictionaries, auxiliary model, iteration count, and other optional arguments.

Return type:

Dict

Raises:

AssertionError – If next_operation is not defined or if hyper_gradient_finished is True.
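To illustrate what a first-order approximation drops, the sketch below compares the exact gradient with respect to an initialization (differentiating through the inner updates, as in initialization-based auto differentiation) against the first-order version that treats the inner trajectory's Jacobian as the identity. The scalar toy losses and the names `inner_adapt` and `outer_grad` are assumptions for the example, not jboat code.

```python
def inner_adapt(y0, c, lr, T):
    """Run T gradient steps on the toy inner loss 0.5 * (y - c)**2."""
    y = y0
    for _ in range(T):
        y = y - lr * (y - c)
    return y


def outer_grad(y, target):
    """Derivative of the toy outer loss 0.5 * (y - target)**2 w.r.t. y."""
    return y - target


y0, c, target, lr, T = 2.0, 0.0, 1.0, 0.1, 5
y_T = inner_adapt(y0, c, lr, T)

# Exact: chain rule through all T steps, since d y_T / d y0 = (1 - lr)**T here.
exact = outer_grad(y_T, target) * (1.0 - lr) ** T
# First-order: ignore d y_T / d y0 and use the outer gradient at y_T directly.
first_order = outer_grad(y_T, target)
```

The first-order value keeps the sign of the exact gradient in this example but overstates its magnitude, which is the usual trade-off for skipping the second-order terms.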

jboat.na_ol.iad

class jboat.na_ol.iad.IAD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Initialization-based Auto Differentiation (IAD) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) – Dictionary containing solver configurations.

References

[1] Finn C., Abbeel P., Levine S., “Model-agnostic meta-learning for fast adaptation of deep networks,” in ICML, 2017.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int) – The number of iterations used for backpropagation.

  • next_operation (str) – The next operator for the calculation of the hypergradient.

  • hyper_gradient_finished (bool) – A boolean flag indicating whether the hypergradient computation is finished.

Returns:

A dictionary containing the upper-level objective and the status of hypergradient computation.

Return type:

Dict

jboat.na_ol.iga

class jboat.na_ol.iga.IGA(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Implicit Gradient Approximation (IGA) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) –

    Dictionary containing solver configurations, including:

    • alpha_init (float): Initial learning rate for GDA.

    • alpha_decay (float): Decay factor for the GDA learning rate.

    • Optional gda_loss (Callable): Custom loss function for GDA, if applicable.

    • gm_op (List[str]): Specifies dynamic operations, e.g., “DI” for dynamic initialization.

alpha

Initial learning rate for GDA operations, if applicable.

Type:

float

alpha_decay

Decay factor applied to the GDA learning rate.

Type:

float

gda_loss

Custom loss function for GDA operations, if specified in solver_config.

Type:

Callable, optional

dynamic_initialization

Indicates whether dynamic initialization is enabled, based on gm_op.

Type:

bool

References

[1] Liu R., Gao J., Liu X., et al., “Learning with constraint learning: New perspective, solution strategy and various applications,” IEEE Transactions on Pattern Analysis and Machine Intelligence, 2024.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables using the given feed dictionaries and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization, including training data, targets, and other information required for the LL objective computation.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization, including validation data, targets, and other information required for the UL objective computation.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower-level model wrapped by the higher library, enabling differentiable optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.

  • hyper_gradient_finished (bool, optional) – A flag indicating whether the hypergradient computation is finished, by default False.

  • next_operation (str, optional) – The next operator for hypergradient calculation. Not supported in this implementation, by default None.

  • **kwargs (dict) –

    Additional arguments, such as:

    • lower_model_params (List[jt.Var]): List of parameters for the lower-level model.

Returns:

A dictionary containing:

  • upper_loss (jt.Var): The upper-level objective value after optimization.

  • hyper_gradient_finished (bool): Indicates whether the hypergradient computation is complete.

Return type:

Dict

Notes

  • This implementation calculates the Gauss-Newton (GN) loss to refine the gradients using second-order approximations.

  • If dynamic_initialization is enabled, the gradients of the lower-level variables are updated with time-dependent parameters.

  • Updates are performed on both lower-level and upper-level variables using computed gradients.

jboat.na_ol.ns

class jboat.na_ol.ns.NS(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Neumann Series (NS) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) –

    Dictionary containing solver configurations, including:

    • gm_op (str): Indicates dynamic initialization type (e.g., “DI”).

    • lower_level_opt (Optimizer): Lower-level optimizer configuration.

    • CG (Dict): Conjugate Gradient-specific parameters:
      • tolerance (float): Tolerance for convergence.

      • k (int): Number of iterations for Neumann approximation.

    • GDA-specific parameters, such as alpha_init and alpha_decay.

    • gda_loss (Callable, optional): Custom loss function for GDA.

References

[1] J. Lorraine, P. Vicol, and D. Duvenaud, “Optimizing millions of hyperparameters by implicit differentiation,” in AISTATS, 2020.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation. Defaults to 0.

  • next_operation (str, optional) – The next operator for the calculation of the hypergradient. Defaults to None.

  • hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished. Defaults to False.

Returns:

A dictionary containing the upper-level objective and the status of hypergradient computation.

Return type:

Dict
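The Neumann-series idea from the cited paper approximates an inverse Hessian-vector product with a truncated sum, H⁻¹g ≈ α Σₖ (I − αH)ᵏ g, which converges when the eigenvalues of αH lie in (0, 2). The sketch below is a pure-Python illustration on a small explicit matrix. The function name `neumann_inverse_hvp`, the step size `alpha`, and the example matrix are assumptions; only the iteration count `K` echoes the configuration key `k` listed above.

```python
def neumann_inverse_hvp(matvec, g, alpha=0.1, K=200):
    """Approximate H^{-1} g by alpha * sum_{k=0}^{K} (I - alpha * H)^k g."""
    v = list(g)   # current term (I - alpha * H)^k g
    p = list(g)   # running sum of terms
    for _ in range(K):
        Hv = matvec(v)
        v = [vi - alpha * hvi for vi, hvi in zip(v, Hv)]
        p = [pi + vi for pi, vi in zip(p, v)]
    return [alpha * pi for pi in p]


H = [[4.0, 1.0], [1.0, 3.0]]  # toy lower-level Hessian


def hvp(v):
    """Hessian-vector product for the toy matrix H."""
    return [sum(hij * vj for hij, vj in zip(row, v)) for row in H]


approx = neumann_inverse_hvp(hvp, [1.0, 2.0])
```

For this matrix the exact answer is (1/11, 7/11). Unlike conjugate gradient, each iteration costs one Hessian-vector product and no inner products, at the price of slower convergence.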

jboat.na_ol.ptt

class jboat.na_ol.ptt.PTT(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Pessimistic Trajectory Truncation (PTT) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) –

    Dictionary containing solver configurations, including:

    • “na_op” (List[str]): Indicates if PTT is used in the hyper-gradient operations.

References

[1] Liu R., Liu Y., Zeng S., et al. “Towards gradient-based bilevel optimization with non-convex followers and beyond,” in NeurIPS, 2021.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation, by default 0.

  • next_operation (str, optional) – The next operator for the calculation of the hypergradient, by default None.

  • hyper_gradient_finished (bool, optional) – A boolean flag indicating whether the hypergradient computation is finished, by default False.

Returns:

A dictionary containing updated feed_dict, auxiliary model, and gradient computation results.

Return type:

Dict
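In the cited paper, pessimistic trajectory truncation selects the lower-level iterate at which the upper-level objective is largest and back-propagates only up to that point. A minimal sketch of the selection step follows. The function name `pessimistic_truncation_index` and the loss values are hypothetical, and whether jboat reports this index through `max_loss_iter` is an assumption rather than something this reference confirms.

```python
def pessimistic_truncation_index(upper_losses):
    """Return the 0-based index of the iterate with the largest upper-level loss."""
    # max() returns the first maximal index, so ties resolve to the earliest iterate.
    return max(range(len(upper_losses)), key=upper_losses.__getitem__)


# Made-up upper-level losses evaluated along a lower-level trajectory.
trajectory_losses = [0.9, 1.4, 1.1, 0.7]
k_bar = pessimistic_truncation_index(trajectory_losses)
truncated_trajectory = trajectory_losses[: k_bar + 1]
```

Only the iterates up to and including `k_bar` would then take part in the backward pass.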

jboat.na_ol.rad

class jboat.na_ol.rad.RAD(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Reverse Auto Differentiation (RAD) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) – Dictionary containing solver configurations, including optional dynamic operation settings.

References

[1] Franceschi L., et al. “Forward and reverse gradient-based hyperparameter optimization,” in ICML, 2017.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables using the provided data and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. Typically includes training data, targets, and other information required to compute the lower-level objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. Typically includes validation data, targets, and other information required to compute the upper-level objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int, optional) – The number of iterations used for backpropagation. Default is 0.

  • next_operation (str, optional) – The next operator for the calculation of the hypergradient. Default is None.

  • **kwargs (dict) – Additional keyword arguments passed to the method.

Returns:

A dictionary containing the upper-level objective and the status of hypergradient computation.

Return type:

Dict
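Reverse-mode differentiation through an unrolled lower-level trajectory can be illustrated on a scalar problem where the result is checkable by hand. The sketch below assumes a toy lower-level loss 0.5·(y − x)², plain gradient descent as the inner solver, and a toy upper-level loss 0.5·(y_T − 1)²; the function names are hypothetical and nothing here reflects jboat's actual implementation.

```python
def unroll(x, y0, lr, T):
    """Return the iterates of T gradient steps on 0.5 * (y - x)**2."""
    ys = [y0]
    for _ in range(T):
        ys.append(ys[-1] - lr * (ys[-1] - x))
    return ys


def reverse_hypergradient(x, y0, lr, T):
    """Back-propagate dF/dx through all T steps, with F = 0.5 * (y_T - 1)**2."""
    ys = unroll(x, y0, lr, T)
    adjoint = ys[-1] - 1.0      # dF/dy_T
    grad_x = 0.0
    for _ in range(T):          # walk the trajectory backwards
        grad_x += adjoint * lr  # each step depends on x directly: dy_{t+1}/dx = lr
        adjoint *= 1.0 - lr     # propagate to the previous iterate: dy_{t+1}/dy_t
    return grad_x


x, y0, lr, T = 0.5, 3.0, 0.2, 10
hg = reverse_hypergradient(x, y0, lr, T)
```

For this problem the closed form is (y_T − 1)·(1 − (1 − lr)ᵀ), which the backward loop reproduces term by term.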

jboat.na_ol.rgt

class jboat.na_ol.rgt.RGT(ll_objective, ul_objective, ll_model, ul_model, ll_var, ul_var, solver_config)[source]

Bases: HyperGradient

Computes the hyper-gradient of the upper-level variables using Reverse Gradient Truncation (RGT) [1].

Parameters:
  • ll_objective (Callable) – The lower-level objective function of the BLO problem.

  • ul_objective (Callable) – The upper-level objective function of the BLO problem.

  • ll_model (jittor.Module) – The lower-level model of the BLO problem.

  • ul_model (jittor.Module) – The upper-level model of the BLO problem.

  • ll_var (List[jittor.Var]) – List of variables optimized with the lower-level objective.

  • ul_var (List[jittor.Var]) – List of variables optimized with the upper-level objective.

  • solver_config (Dict[str, Any]) – Dictionary containing solver configurations, including the hyper-gradient operations and truncation settings.

References

[1] Shaban A., Cheng C.A., Hatch N., et al. “Truncated back-propagation for bilevel optimization,” in AISTATS, 2019.

compute_gradients(ll_feed_dict, ul_feed_dict, auxiliary_model, max_loss_iter=0, hyper_gradient_finished=False, next_operation=None, **kwargs)[source]

Compute the hyper-gradients of the upper-level variables with the data from feed_dict and patched models.

Parameters:
  • ll_feed_dict (Dict) – Dictionary containing the lower-level data used for optimization. It typically includes training data, targets, and other information required to compute the LL objective.

  • ul_feed_dict (Dict) – Dictionary containing the upper-level data used for optimization. It typically includes validation data, targets, and other information required to compute the UL objective.

  • auxiliary_model (_MonkeyPatchBase) – A patched lower model wrapped by the higher library. It serves as the lower-level model for optimization.

  • max_loss_iter (int) – The number of iterations used for backpropagation.

  • next_operation (str) – The next operator for the calculation of the hypergradient.

  • hyper_gradient_finished (bool) – A boolean flag indicating whether the hypergradient computation is finished.

Returns:

The current upper-level objective.

Return type:

Any
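Truncated back-propagation, as in the cited paper, keeps only the last few steps of the backward pass through the lower-level trajectory, trading some bias for lower memory and compute. The sketch below uses a scalar toy problem (lower-level loss 0.5·(y − x)², upper-level loss 0.5·(y_T − 1)²). The function name `truncated_hypergradient` and the parameter `keep` are assumptions for illustration, not jboat configuration keys.

```python
def truncated_hypergradient(x, y0, lr, T, keep):
    """Back-propagate dF/dx through only the last `keep` of T inner steps."""
    y = y0
    for _ in range(T):              # forward pass: T gradient steps
        y = y - lr * (y - x)
    adjoint = y - 1.0               # dF/dy_T for F = 0.5 * (y_T - 1)**2
    grad_x = 0.0
    for _ in range(min(keep, T)):   # backward pass stops after `keep` steps
        grad_x += adjoint * lr      # direct dependence of each step on x
        adjoint *= 1.0 - lr         # move the adjoint one step back
    return grad_x, y


x, y0, lr, T = 0.5, 3.0, 0.2, 10
full, y_T = truncated_hypergradient(x, y0, lr, T, keep=T)
short, _ = truncated_hypergradient(x, y0, lr, T, keep=3)
```

With `keep=T` this matches full reverse-mode differentiation; with `keep=3` the result equals (y_T − 1)·(1 − (1 − lr)³), so the omitted early steps account for the gap between `short` and `full`.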