Входной файл: | Стандартный вход | Ограничение времени: | 1 сек | |
Выходной файл: | Стандартный выход | Ограничение памяти: | 512 Мб |
Требуется реализовать на языке Python класс Adam
, который описывает одноименный алгоритм и имеет следующий интерфейс
import numpy as np
class Adam:
'''Represents an Adam optimizer
Fields:
eta: learning rate
beta1: first moment decay rate
beta2: second moment decay rate
epsilon: smoothing term
'''
eta: float
beta1: float
beta2: float
epsilon: float
def __init__(self, *, eta: float = 0.1, beta1: float = 0.9, beta2: float = 0.999, epsilon: float = 1e-8):
'''Initalizes `eta`, `beta1` and `beta2` fields'''
raise NotImplementedError()
def optimize(self, oracle: Oracle, x0: np.ndarray, *,
max_iter: int = 100, eps: float = 1e-5) -> np.ndarray:
'''Optimizes a function specified as `oracle` starting from point `x0`.
The optimizations stops when `max_iter` iterations were completed or
the L2-norm of the gradient at current point is less than `eps`
Args:
oracle: function to optimize
x0: point to start from
max_iter: maximal number of iterations
eps: threshold for L2-norm of gradient
Returns:
A point at which the optimization stopped
'''
raise NotImplementedError()
Параметрами алгоритма являются:
eta
— learning rate,beta1
— коэффициент затухания первого момента,beta2
— коэффициент затухания второго момента,epsilon
— сглаживающий коэффициент.oracle
— оптимизируемая функция,x0
— начальная точка,max_iter
— максимальное количество итераций,eps
— пороговое значение L2 нормы градиента.max_iter
количества итераций или при достижении точки, в которой L2 норма градиента меньше eps
.
Класс Oracle
описывает оптимизируемую функцию
import numpy as np
class Oracle:
'''Provides an interface for evaluating a function and its derivative at arbitrary point'''
def value(self, x: np.ndarray) -> float:
'''Evaluates the underlying function at point `x`
Args:
x: a point to evaluate funciton at
Returns:
Function value
'''
raise NotImplementedError()
def gradient(self, x: np.ndarray) -> np.ndarray:
'''Evaluates the underlying function derivative at point `x`
Args:
x: a point to evaluate derivative at
Returns:
Function derivative
'''
raise NotImplementedError()
Код решения должен содержать только определение и реализацию класса.