Source code for jboat.na_ol.hyper_gradient

from abc import abstractmethod
from jboat.operation_registry import register_class

importlib = __import__("importlib")


[docs] @register_class class HyperGradient(object): """ Base class for computing hyper-gradients of upper-level variables in bilevel optimization problems. This class provides an abstract interface for hyper-gradient computation that can be extended for specific methods such as Conjugate Gradient, Finite Differentiation, or First-Order Approximation. Parameters ---------- ll_objective : callable The lower-level objective function of the bilevel optimization problem. ul_objective : callable The upper-level objective function of the bilevel optimization problem. ul_model : jittor.Module The upper-level model of the bilevel optimization problem. ll_model : jittor.Module The lower-level model of the bilevel optimization problem. ll_var : List[jittor.Var] A list of variables optimized with the lower-level objective. ul_var : List[jittor.Var] A list of variables optimized with the upper-level objective. solver_config : dict Dictionary containing configurations for the solver. """ def __init__( self, ll_objective, ul_objective, ul_model, ll_model, ll_var, ul_var, solver_config, ): self.ll_objective = ll_objective self.ul_objective = ul_objective self.ul_model = ul_model self.ll_model = ll_model self.ll_var = ll_var self.ul_var = ul_var self.solver_configs = solver_config
[docs] @abstractmethod def compute_gradients(self, **kwargs): pass