jnp.linalg.norm 是 JAX 中用于计算向量或矩阵的范数的函数。JAX 是一个用于高性能机器学习研究的 Python 库,它提供了与 NumPy 类似的 API,但支持自动微分和加速计算。jnp 是 JAX 的 NumPy 接口。
jnp.linalg.norm 的基本语法
jnp.linalg.norm(x, ord=None, axis=None, keepdims=False)
参数
- x:要计算范数的输入数组。可以是向量(1D 数组)或矩阵(2D 数组)。
- ord:指定要计算的范数的类型。可以是以下值之一:
None:默认的欧几里得范数(L2 范数)。1:L1 范数,向量元素绝对值之和。2:L2 范数,向量元素平方和的平方根。inf:最大范数,向量元素的最大绝对值。-inf:最小范数,向量元素的最小绝对值。- 对于矩阵,
ord可以是以下值之一:'fro'或None:Frobenius 范数(元素平方和的平方根)。1:列和范数(每列元素绝对值之和的最大值)。inf:行和范数(每行元素绝对值之和的最大值)。
- axis:指定沿哪个轴计算范数。如果为
None,则计算整个数组的范数。对于向量,可以是一个整数;对于矩阵,可以是一个长度为 2 的元组,指定计算的维度。 - keepdims:如果为
True,则在结果中保持原数组的维度。这对于保持与输入数组的形状一致性很有用。
返回值
返回计算后的范数值。如果 axis 为 None,则返回单个值;否则返回按指定轴计算的范数。
示例
计算向量的 L2 范数(默认)
import jax.numpy as jnpx = jnp.array([1, 2, 3])
l2_norm = jnp.linalg.norm(x)
print(l2_norm) # 输出: 3.7416573867739413
计算向量的 L1 范数
l1_norm = jnp.linalg.norm(x, ord=1)
print(l1_norm) # 输出: 6.0
计算矩阵的 Frobenius 范数
A = jnp.array([[1, 2, 3], [4, 5, 6]])
frobenius_norm = jnp.linalg.norm(A)
print(frobenius_norm) # 输出: 9.539392014169456
计算矩阵的列和范数
column_sum_norm = jnp.linalg.norm(A, ord=1)
print(column_sum_norm) # 输出: 9.0
计算矩阵的行和范数
row_sum_norm = jnp.linalg.norm(A, ord=jnp.inf)
print(row_sum_norm) # 输出: 15.0
沿指定轴计算范数
计算每列的 L2 范数:
column_l2_norms = jnp.linalg.norm(A, axis=0)
print(column_l2_norms) # 输出: [4.1231055 5.3851647 6.708204]
计算每行的 L2 范数:
row_l2_norms = jnp.linalg.norm(A, axis=1)
print(row_l2_norms) # 输出: [ 3.7416575 8.774964 ]
总结
jnp.linalg.norm 是一个强大且灵活的工具,用于计算向量和矩阵的各种范数。通过指定不同的 ord 和 axis 参数,可以计算出不同类型和不同轴上的范数。