- """
- Loss functions for linear models with raw_prediction = X @ coef
- """
- import numpy as np
- from scipy import sparse
- from ..utils.extmath import squared_norm
- class LinearModelLoss:
- """General class for loss functions with raw_prediction = X @ coef + intercept.
- Note that raw_prediction is also known as linear predictor.
- The loss is the sum of per sample losses and includes a term for L2
- regularization::
- loss = sum_i s_i loss(y_i, X_i @ coef + intercept)
- + 1/2 * l2_reg_strength * ||coef||_2^2
- with sample weights s_i=1 if sample_weight=None.
- Gradient and hessian, for simplicity without intercept, are::
- gradient = X.T @ loss.gradient + l2_reg_strength * coef
- hessian = X.T @ diag(loss.hessian) @ X + l2_reg_strength * identity
- Conventions:
- if fit_intercept:
- n_dof = n_features + 1
- else:
- n_dof = n_features
- if base_loss.is_multiclass:
- coef.shape = (n_classes, n_dof) or ravelled (n_classes * n_dof,)
- else:
- coef.shape = (n_dof,)
- The intercept term is at the end of the coef array:
- if base_loss.is_multiclass:
            if coef.shape == (n_classes, n_dof):
                intercept = coef[:, -1]
            if coef.shape == (n_classes * n_dof,):
                intercept = coef[n_features * n_classes:] = coef[-n_classes:]
            intercept.shape = (n_classes,)
        else:
            intercept = coef[-1]

    Note: If coef has shape (n_classes * n_dof,), the 2d-array can be reconstructed as

        coef.reshape((n_classes, -1), order="F")

    The option order="F" makes coef[:, i] contiguous. This, in turn, makes the
    coefficients without intercept, coef[:, :-1], contiguous and speeds up
    matrix-vector computations.

    Note: If the average loss per sample is wanted instead of the sum of the loss per
    sample, one can simply use a rescaled sample_weight such that
    sum(sample_weight) = 1.

    Parameters
    ----------
    base_loss : instance of class BaseLoss from sklearn._loss.
    fit_intercept : bool
    """
    def __init__(self, base_loss, fit_intercept):
        self.base_loss = base_loss
        self.fit_intercept = fit_intercept

    def init_zero_coef(self, X, dtype=None):
        """Allocate coef of correct shape with zeros.

        Parameters
        ----------
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        dtype : data-type, default=None
            Overrides the data type of coef. With dtype=None, coef will have the same
            dtype as X.

        Returns
        -------
        coef : ndarray of shape (n_dof,) or (n_classes, n_dof)
            Coefficients of a linear model.
        """
        n_features = X.shape[1]
        n_classes = self.base_loss.n_classes
        if self.fit_intercept:
            n_dof = n_features + 1
        else:
            n_dof = n_features
        if self.base_loss.is_multiclass:
            coef = np.zeros_like(X, shape=(n_classes, n_dof), dtype=dtype, order="F")
        else:
            coef = np.zeros_like(X, shape=n_dof, dtype=dtype)
        return coef

    def weight_intercept(self, coef):
        """Helper function to get coefficients and intercept.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").

        Returns
        -------
        weights : ndarray of shape (n_features,) or (n_classes, n_features)
            Coefficients without intercept term.
        intercept : float or ndarray of shape (n_classes,)
            Intercept terms.
        """
        if not self.base_loss.is_multiclass:
            if self.fit_intercept:
                intercept = coef[-1]
                weights = coef[:-1]
            else:
                intercept = 0.0
                weights = coef
        else:
            # reshape to (n_classes, n_dof)
            if coef.ndim == 1:
                weights = coef.reshape((self.base_loss.n_classes, -1), order="F")
            else:
                weights = coef
            if self.fit_intercept:
                intercept = weights[:, -1]
                weights = weights[:, :-1]
            else:
                intercept = 0.0

        return weights, intercept

    def weight_intercept_raw(self, coef, X):
        """Helper function to get coefficients, intercept and raw_prediction.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.

        Returns
        -------
        weights : ndarray of shape (n_features,) or (n_classes, n_features)
            Coefficients without intercept term.
        intercept : float or ndarray of shape (n_classes,)
            Intercept terms.
        raw_prediction : ndarray of shape (n_samples,) or \
            (n_samples, n_classes)
        """
        weights, intercept = self.weight_intercept(coef)

        if not self.base_loss.is_multiclass:
            raw_prediction = X @ weights + intercept
        else:
            # weights has shape (n_classes, n_dof)
            raw_prediction = X @ weights.T + intercept  # ndarray, likely C-contiguous

        return weights, intercept, raw_prediction

    def l2_penalty(self, weights, l2_reg_strength):
        """Compute L2 penalty term l2_reg_strength/2 * ||w||_2^2."""
        norm2_w = weights @ weights if weights.ndim == 1 else squared_norm(weights)
        return 0.5 * l2_reg_strength * norm2_w

    def loss(
        self,
        coef,
        X,
        y,
        sample_weight=None,
        l2_reg_strength=0.0,
        n_threads=1,
        raw_prediction=None,
    ):
        """Compute the loss as sum over point-wise losses.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength.
        n_threads : int, default=1
            Number of OpenMP threads to use.
        raw_prediction : C-contiguous array of shape (n_samples,) or array of \
            shape (n_samples, n_classes)
            Raw prediction values (in link space). If provided, these are used. If
            None, then raw_prediction = X @ coef + intercept is calculated.

        Returns
        -------
        loss : float
            Sum of losses per sample plus penalty.
        """
        if raw_prediction is None:
            weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
        else:
            weights, intercept = self.weight_intercept(coef)

        loss = self.base_loss.loss(
            y_true=y,
            raw_prediction=raw_prediction,
            sample_weight=sample_weight,
            n_threads=n_threads,
        )
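        # base_loss.loss returns one loss value per sample; summing them gives
        # the data term of the objective.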
        loss = loss.sum()

        return loss + self.l2_penalty(weights, l2_reg_strength)

    def loss_gradient(
        self,
        coef,
        X,
        y,
        sample_weight=None,
        l2_reg_strength=0.0,
        n_threads=1,
        raw_prediction=None,
    ):
        """Computes the sum of loss and gradient w.r.t. coef.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength.
        n_threads : int, default=1
            Number of OpenMP threads to use.
        raw_prediction : C-contiguous array of shape (n_samples,) or array of \
            shape (n_samples, n_classes)
            Raw prediction values (in link space). If provided, these are used. If
            None, then raw_prediction = X @ coef + intercept is calculated.

        Returns
        -------
        loss : float
            Sum of losses per sample plus penalty.
        gradient : ndarray of shape coef.shape
            The gradient of the loss.
        """
        n_features, n_classes = X.shape[1], self.base_loss.n_classes
        n_dof = n_features + int(self.fit_intercept)

        if raw_prediction is None:
            weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
        else:
            weights, intercept = self.weight_intercept(coef)

        loss, grad_pointwise = self.base_loss.loss_gradient(
            y_true=y,
            raw_prediction=raw_prediction,
            sample_weight=sample_weight,
            n_threads=n_threads,
        )
        loss = loss.sum()
        loss += self.l2_penalty(weights, l2_reg_strength)

        if not self.base_loss.is_multiclass:
            grad = np.empty_like(coef, dtype=weights.dtype)
            grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
            if self.fit_intercept:
                grad[-1] = grad_pointwise.sum()
        else:
            grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
            # grad_pointwise.shape = (n_samples, n_classes)
            grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
            if self.fit_intercept:
                grad[:, -1] = grad_pointwise.sum(axis=0)
            if coef.ndim == 1:
                grad = grad.ravel(order="F")

        return loss, grad

    def gradient(
        self,
        coef,
        X,
        y,
        sample_weight=None,
        l2_reg_strength=0.0,
        n_threads=1,
        raw_prediction=None,
    ):
        """Computes the gradient w.r.t. coef.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength.
        n_threads : int, default=1
            Number of OpenMP threads to use.
        raw_prediction : C-contiguous array of shape (n_samples,) or array of \
            shape (n_samples, n_classes)
            Raw prediction values (in link space). If provided, these are used. If
            None, then raw_prediction = X @ coef + intercept is calculated.

        Returns
        -------
        gradient : ndarray of shape coef.shape
            The gradient of the loss.
        """
        n_features, n_classes = X.shape[1], self.base_loss.n_classes
        n_dof = n_features + int(self.fit_intercept)

        if raw_prediction is None:
            weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
        else:
            weights, intercept = self.weight_intercept(coef)

        grad_pointwise = self.base_loss.gradient(
            y_true=y,
            raw_prediction=raw_prediction,
            sample_weight=sample_weight,
            n_threads=n_threads,
        )

        if not self.base_loss.is_multiclass:
            grad = np.empty_like(coef, dtype=weights.dtype)
            grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
            if self.fit_intercept:
                grad[-1] = grad_pointwise.sum()
            return grad
        else:
            grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
            # grad_pointwise.shape = (n_samples, n_classes)
            grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
            if self.fit_intercept:
                grad[:, -1] = grad_pointwise.sum(axis=0)
            if coef.ndim == 1:
                return grad.ravel(order="F")
            else:
                return grad

    def gradient_hessian(
        self,
        coef,
        X,
        y,
        sample_weight=None,
        l2_reg_strength=0.0,
        n_threads=1,
        gradient_out=None,
        hessian_out=None,
        raw_prediction=None,
    ):
        """Computes gradient and hessian w.r.t. coef.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength.
        n_threads : int, default=1
            Number of OpenMP threads to use.
        gradient_out : None or ndarray of shape coef.shape
            A location into which the gradient is stored. If None, a new array
            might be created.
        hessian_out : None or ndarray
            A location into which the hessian is stored. If None, a new array
            might be created.
        raw_prediction : C-contiguous array of shape (n_samples,) or array of \
            shape (n_samples, n_classes)
            Raw prediction values (in link space). If provided, these are used. If
            None, then raw_prediction = X @ coef + intercept is calculated.

        Returns
        -------
        gradient : ndarray of shape coef.shape
            The gradient of the loss.
        hessian : ndarray
            Hessian matrix.
        hessian_warning : bool
            True if more than 25% of the pointwise hessian elements are
            non-positive.
        """
        n_samples, n_features = X.shape
        n_dof = n_features + int(self.fit_intercept)

        if raw_prediction is None:
            weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)
        else:
            weights, intercept = self.weight_intercept(coef)

        grad_pointwise, hess_pointwise = self.base_loss.gradient_hessian(
            y_true=y,
            raw_prediction=raw_prediction,
            sample_weight=sample_weight,
            n_threads=n_threads,
        )

        # For non-canonical link functions and far away from the optimum, the
        # pointwise hessian can be negative. We flag a warning when more than
        # 25% of the hessian entries are non-positive.
        hessian_warning = np.mean(hess_pointwise <= 0) > 0.25
        hess_pointwise = np.abs(hess_pointwise)
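        # Taking the absolute values above keeps the assembled matrix
        # X' @ diag(h) @ X positive semidefinite, so Newton-type solvers still
        # obtain a valid curvature approximation.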
        if not self.base_loss.is_multiclass:
            # gradient
            if gradient_out is None:
                grad = np.empty_like(coef, dtype=weights.dtype)
            else:
                grad = gradient_out
            grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
            if self.fit_intercept:
                grad[-1] = grad_pointwise.sum()

            # hessian
            if hessian_out is None:
                hess = np.empty(shape=(n_dof, n_dof), dtype=weights.dtype)
            else:
                hess = hessian_out

            if hessian_warning:
                # Exit early without computing the hessian.
                return grad, hess, hessian_warning

            # TODO: This "sandwich product", X' diag(W) X, is the main computational
            # bottleneck for solvers. A dedicated Cython routine might improve it
            # by exploiting the symmetry (as opposed to, e.g., BLAS gemm).
            if sparse.issparse(X):
                hess[:n_features, :n_features] = (
                    X.T
                    @ sparse.dia_matrix(
                        (hess_pointwise, 0), shape=(n_samples, n_samples)
                    )
                    @ X
                ).toarray()
            else:
                # np.einsum may use less memory but the following, using BLAS matrix
                # multiplication (gemm), is by far faster.
                WX = hess_pointwise[:, None] * X
                hess[:n_features, :n_features] = np.dot(X.T, WX)

            if l2_reg_strength > 0:
                # The L2 penalty enters the Hessian on the diagonal only. To add those
                # terms, we use a flattened view on the array.
                hess.reshape(-1)[
                    : (n_features * n_dof) : (n_dof + 1)
                ] += l2_reg_strength
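                # In the flattened C-order view of the (n_dof, n_dof) array,
                # consecutive diagonal entries are n_dof + 1 apart, so the slice
                # above hits exactly the first n_features diagonal elements and
                # leaves the intercept dimension unpenalized.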
            if self.fit_intercept:
                # With intercept included as added column to X, the hessian becomes
                #     hess = (X, 1)' @ diag(h) @ (X, 1)
                #          = (X' @ diag(h) @ X,  X' @ h)
                #            (h @ X,             sum(h))
                # The left upper part has already been filled, it remains to compute
                # the last row and the last column.
                Xh = X.T @ hess_pointwise
                hess[:-1, -1] = Xh
                hess[-1, :-1] = Xh
                hess[-1, -1] = hess_pointwise.sum()
        else:
            # Here we may safely assume HalfMultinomialLoss aka categorical
            # cross-entropy.
            raise NotImplementedError

        return grad, hess, hessian_warning

    def gradient_hessian_product(
        self, coef, X, y, sample_weight=None, l2_reg_strength=0.0, n_threads=1
    ):
        """Computes gradient and hessp (hessian product function) w.r.t. coef.

        Parameters
        ----------
        coef : ndarray of shape (n_dof,), (n_classes, n_dof) or (n_classes * n_dof,)
            Coefficients of a linear model.
            If shape (n_classes * n_dof,), the classes of one feature are contiguous,
            i.e. one reconstructs the 2d-array via
            coef.reshape((n_classes, -1), order="F").
        X : {array-like, sparse matrix} of shape (n_samples, n_features)
            Training data.
        y : contiguous array of shape (n_samples,)
            Observed, true target values.
        sample_weight : None or contiguous array of shape (n_samples,), default=None
            Sample weights.
        l2_reg_strength : float, default=0.0
            L2 regularization strength.
        n_threads : int, default=1
            Number of OpenMP threads to use.

        Returns
        -------
        gradient : ndarray of shape coef.shape
            The gradient of the loss.
        hessp : callable
            Function that takes a vector input of the same shape as gradient and
            returns the matrix-vector product with the hessian.
        """
        (n_samples, n_features), n_classes = X.shape, self.base_loss.n_classes
        n_dof = n_features + int(self.fit_intercept)
        weights, intercept, raw_prediction = self.weight_intercept_raw(coef, X)

        if not self.base_loss.is_multiclass:
            grad_pointwise, hess_pointwise = self.base_loss.gradient_hessian(
                y_true=y,
                raw_prediction=raw_prediction,
                sample_weight=sample_weight,
                n_threads=n_threads,
            )
            grad = np.empty_like(coef, dtype=weights.dtype)
            grad[:n_features] = X.T @ grad_pointwise + l2_reg_strength * weights
            if self.fit_intercept:
                grad[-1] = grad_pointwise.sum()

            # Precompute as much as possible: hX, hX_sum and hessian_sum
            hessian_sum = hess_pointwise.sum()
            if sparse.issparse(X):
                hX = (
                    sparse.dia_matrix(
                        (hess_pointwise, 0), shape=(n_samples, n_samples)
                    )
                    @ X
                )
            else:
                hX = hess_pointwise[:, np.newaxis] * X

            if self.fit_intercept:
                # Calculate the double derivative with respect to intercept.
                # Note: In case hX is sparse, hX.sum is a matrix object.
                hX_sum = np.squeeze(np.asarray(hX.sum(axis=0)))
                # prevent squeezing to zero-dim array if n_features == 1
                hX_sum = np.atleast_1d(hX_sum)

            # With intercept included and l2_reg_strength = 0, hessp returns
            #     res = (X, 1)' @ diag(h) @ (X, 1) @ s
            #         = (X, 1)' @ (hX @ s[:n_features] + h * s[-1])
            # so that, with hX_sum = X' @ h,
            #     res[:n_features] = X' @ hX @ s[:n_features] + hX_sum * s[-1]
            #     res[-1] = hX_sum @ s[:n_features] + sum(h) * s[-1]
            def hessp(s):
                ret = np.empty_like(s)
                if sparse.issparse(X):
                    ret[:n_features] = X.T @ (hX @ s[:n_features])
                else:
                    ret[:n_features] = np.linalg.multi_dot([X.T, hX, s[:n_features]])
                ret[:n_features] += l2_reg_strength * s[:n_features]
                if self.fit_intercept:
                    ret[:n_features] += s[-1] * hX_sum
                    ret[-1] = hX_sum @ s[:n_features] + hessian_sum * s[-1]
                return ret

        else:
            # Here we may safely assume HalfMultinomialLoss aka categorical
            # cross-entropy.
            # HalfMultinomialLoss computes only the diagonal part of the hessian, i.e.
            # diagonal in the classes. Here, we want the matrix-vector product of the
            # full hessian. Therefore, we call gradient_proba.
            grad_pointwise, proba = self.base_loss.gradient_proba(
                y_true=y,
                raw_prediction=raw_prediction,
                sample_weight=sample_weight,
                n_threads=n_threads,
            )
            grad = np.empty((n_classes, n_dof), dtype=weights.dtype, order="F")
            grad[:, :n_features] = grad_pointwise.T @ X + l2_reg_strength * weights
            if self.fit_intercept:
                grad[:, -1] = grad_pointwise.sum(axis=0)

            # Full hessian-vector product, i.e. not only the diagonal part of the
            # hessian. Derivation with some index battle for input vector s:
            #   - sample index i
            #   - feature indices j, m
            #   - class indices k, l
            #   - 1_{k=l} is one if k=l else 0
            #   - p_i_k is the (predicted) probability that sample i belongs to
            #     class k; for all i: sum_k p_i_k = 1
            #   - s_l_m is input vector for class l and feature m
            #   - X' = X transposed
            #
            # Note: Hessian with dropping most indices is just:
            #       X' @ p_k (1_{k=l} - p_l) @ X
            #
            # result_{k j} = sum_{i, l, m} Hessian_{i, k j, m l} * s_l_m
            #              = sum_{i, l, m} (X')_{ji} * p_i_k * (1_{k=l} - p_i_l)
            #                              * X_{im} s_l_m
            #              = sum_{i, m} (X')_{ji} * p_i_k
            #                           * (X_{im} * s_k_m - sum_l p_i_l * X_{im} * s_l_m)
            #
            # See also https://github.com/scikit-learn/scikit-learn/pull/3646#discussion_r17461411  # noqa
            def hessp(s):
                s = s.reshape((n_classes, -1), order="F")  # shape = (n_classes, n_dof)
                if self.fit_intercept:
                    s_intercept = s[:, -1]
                    s = s[:, :-1]  # shape = (n_classes, n_features)
                else:
                    s_intercept = 0
                tmp = X @ s.T + s_intercept  # X_{im} * s_k_m
                tmp += (-proba * tmp).sum(axis=1)[:, np.newaxis]  # - sum_l ..
                tmp *= proba  # * p_i_k
                if sample_weight is not None:
                    tmp *= sample_weight[:, np.newaxis]
                # hess_prod = empty_like(grad), but we ravel grad below and this
                # function is run after that.
                hess_prod = np.empty(
                    (n_classes, n_dof), dtype=weights.dtype, order="F"
                )
                hess_prod[:, :n_features] = tmp.T @ X + l2_reg_strength * s
                if self.fit_intercept:
                    hess_prod[:, -1] = tmp.sum(axis=0)
                if coef.ndim == 1:
                    return hess_prod.ravel(order="F")
                else:
                    return hess_prod

        if coef.ndim == 1:
            return grad.ravel(order="F"), hessp

        return grad, hessp
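
# Usage sketch (illustrative comment, not executed): how a solver typically
# consumes LinearModelLoss. It assumes scikit-learn is installed so that the
# absolute imports resolve; HalfBinomialLoss (the base loss of binary logistic
# regression) lives in a private module whose path may change between versions.
#
#     import numpy as np
#     from sklearn._loss.loss import HalfBinomialLoss
#     from sklearn.linear_model._linear_loss import LinearModelLoss
#
#     rng = np.random.default_rng(0)
#     X = rng.standard_normal((20, 3))
#     y = (rng.uniform(size=20) > 0.5).astype(X.dtype)
#
#     linear_loss = LinearModelLoss(base_loss=HalfBinomialLoss(), fit_intercept=True)
#     coef = linear_loss.init_zero_coef(X)  # shape (n_features + 1,)
#     value, grad = linear_loss.loss_gradient(coef, X, y, l2_reg_strength=1.0)
#
#     # Sanity check: grad should match central finite differences of loss().
#     eps = 1e-6
#     fd = np.array(
#         [
#             (
#                 linear_loss.loss(coef + eps * e, X, y, l2_reg_strength=1.0)
#                 - linear_loss.loss(coef - eps * e, X, y, l2_reg_strength=1.0)
#             )
#             / (2 * eps)
#             for e in np.eye(coef.shape[0])
#         ]
#     )
#     assert np.allclose(grad, fd, atol=1e-5)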