Задача C. Nesterov Accelerated Gradient

Входной файл:Стандартный вход   Ограничение времени:1 сек
Выходной файл:Стандартный выход   Ограничение памяти:512 Мб

Условие

Требуется реализовать на языке Python класс NesterovAG, который описывает алгоритм ускоренного градиента Нестерова и имеет следующий интерфейс


import numpy as np

class NesterovAG:
    '''Represents a Nesterov Accelerated Gradient optimizer

    Fields:
        eta: learning rate
        alpha: exponential decay factor
    '''

    eta: float
    alpha: float

    def __init__(self, *, alpha: float = 0.9, eta: float = 0.1):
        '''Initalizes `eta` and `aplha` fields'''
        raise NotImplementedError()

    def optimize(self, oracle: Oracle, x0: np.ndarray, *,
                 max_iter: int = 100, eps: float = 1e-5) -> np.ndarray:
        '''Optimizes a function specified as `oracle` starting from point `x0`.
        The optimizations stops when `max_iter` iterations were completed or 
        the L2-norm of the current gradient is less than `eps`

        Args:
            oracle: function to optimize
            x0: point to start from
            max_iter: maximal number of iterations
            eps: threshold for L2-norm of gradient

        Returns:
            A point at which the optimization stopped
        '''
        raise NotImplementedError()
Параметрами алгоритма являются: Параметрами процесса оптимизации являются: Оптимизация останавливается при достижении max_iter количества итераций или при достижении точки, в которой L2 норма градиента меньше eps.

Класс Oracle описывает оптимизируемую функцию


import numpy as np

class Oracle:
    '''Provides an interface for evaluating a function and its derivative at arbitrary point'''
    
    def value(self, x: np.ndarray) -> float:
        '''Evaluates the underlying function at point `x`

        Args:
            x: a point to evaluate funciton at

        Returns:
            Function value
        '''
        raise NotImplementedError()
        
    def gradient(self, x: np.ndarray) -> np.ndarray:
        '''Evaluates the underlying function derivative at point `x`

        Args:
            x: a point to evaluate derivative at

        Returns:
            Function derivative
        '''
        raise NotImplementedError()

Формат выходных данных

Код решения должен содержать только определение и реализацию класса.


0.152s 0.014s 15