Задача 2B. Gradient Descent with Momentum

Входной файл:Стандартный вход   Ограничение времени:1 сек
Выходной файл:Стандартный выход   Ограничение памяти:512 Мб
Максимальный балл:1  

Условие

Требуется реализовать на языке Python класс GDM, который описывает алгоритм градиентного спуска с моментом и имеет следующий интерфейс:

import numpy as np class GDM: '''Represents a Gradient Descent with Momentum optimizer Fields: eta: learning rate alpha: exponential decay factor ''' eta: float alpha: float def __init__(self, *, alpha: float = 0.9, eta: float = 0.1): '''Initalizes `eta` and `alpha` fields''' raise NotImplementedError() def optimize(self, oracle: Oracle, x0: np.ndarray, *, max_iter: int = 100, eps: float = 1e-5) -> np.ndarray: '''Optimizes a function specified as `oracle` starting from point `x0`. The optimizations stops when `max_iter` iterations were completed or the L2-norm of the gradient at current point is less than `eps` Args: oracle: function to optimize x0: point to start from max_iter: maximal number of iterations eps: threshold for L2-norm of gradient Returns: A point at which the optimization stopped ''' raise NotImplementedError()

Параметрами алгоритма являются:

Параметрами процесса оптимизации являются: Оптимизация останавливается при достижении max_iter количества итераций или при достижении точки, в которой L2 норма градиента меньше eps.

Класс Oracle описывает оптимизируемую функцию:

import numpy as np class Oracle: '''Provides an interface for evaluating a function and its derivative at arbitrary point''' def value(self, x: np.ndarray) -> float: '''Evaluates the underlying function at point `x` Args: x: a point to evaluate funciton at Returns: Function value ''' raise NotImplementedError() def gradient(self, x: np.ndarray) -> np.ndarray: '''Evaluates the underlying function derivative at point `x` Args: x: a point to evaluate derivative at Returns: Function derivative ''' raise NotImplementedError()

Формат выходных данных

Код решения должен содержать только определение и реализацию класса.


0.182s 0.015s 17