# Source code for pyacs.lib.glinalg.extract_block_diag
"""Extract block diagonal from a matrix."""
import numpy as np
# [docs]
def extract_block_diag(a, n, k=0):
    """Extract the square blocks lying on a block diagonal of a 2D array.

    Parameters
    ----------
    a : array_like
        2D array.
    n : int
        Block size; each extracted block is ``n`` x ``n``. Must be positive.
    k : int, optional
        Block-diagonal offset, counted in blocks. ``k > 0`` skips the first
        ``n * k`` columns (diagonals above the main one); ``k < 0`` skips the
        first ``-n * k`` rows (diagonals below it). Default is 0.

    Returns
    -------
    numpy.ndarray
        Block diagonal elements as 3D array of shape (n_blocks, n, n), where
        ``n_blocks`` is the number of complete blocks that fit on the
        selected diagonal. Trailing rows/columns that do not fill a whole
        block are ignored. The result is a strided view created with
        ``as_strided``, not a copy, so it shares memory with the array
        returned by ``np.asarray(a)``.

    Raises
    ------
    ValueError
        If ``a`` is not 2-dimensional, or if ``n`` is not positive.
    """
    a = np.asarray(a)
    if a.ndim != 2:
        raise ValueError("Only 2-D arrays handled")
    if not (n > 0):
        # n == 0 is rejected as well: it would divide by zero below.
        raise ValueError("Must have n > 0")

    # Shift the origin so the requested block diagonal becomes the main one.
    if k > 0:
        a = a[:, n * k:]
    else:
        a = a[-n * k:]

    n_blocks = min(a.shape[0] // n, a.shape[1] // n)

    row_stride, col_stride = a.strides
    new_shape = (n_blocks, n, n)
    # Moving to the next block means stepping n rows down and n columns
    # right; inside a block the ordinary row/column strides apply.
    new_strides = (n * row_stride + n * col_stride, row_stride, col_stride)

    return np.lib.stride_tricks.as_strided(a, new_shape, new_strides)