基本数据类型

Python3的六个标准数据类型中:

  • 不可变数据:Number(数字)、String(字符串)、Tuple(元组)
  • 可变数据:List(列表)、Dictionary(字典)、Set(集合)

不可变指的是创建后内存内容不能再被修改,只能创建新的对象来修改

数字Number

Python 支持三种不同的数值类型:

  • 整型(int),Python3整型是没有限制大小的,可以当作 Long 类型使用,布尔(bool)是整型的子类型
  • 浮点型(float)
  • 复数(complex),可以用a+bj,或者complex(a,b)表示,复数的实部a和虚部b都是浮点型

数据类型的转换只需要将数据类型作为函数名即可

在混合计算时会把整型转换成为浮点数,+,-,*/和其它语言里一样

数值的除法包含两个运算符:/返回一个浮点数,//返回一个整数(但不一定是int,和分子分母的数据类型有关)

数学常量:pi and e

字符串String

用单引号'或双引号"括起来,同时使用反斜杠\转义特殊字符,注意字符串不可变

字符串的切片:变量[头下标:尾下标],索引值以0为开始值,-1为从末尾的开始位置

可以用+运算符连接在一起,用*运算符重复

格式化推荐用f-string(python3.6之后)

f开头,后面跟着字符串,字符串中的表达式用大括号 {} 包起来,它会将变量或表达式计算后的值替换进去

用这种方式就不用再去判断使用%s还是%d

列表List

列表是Python中使用最频繁的数据类型,写在方括号[]之间,元素用逗号隔开

和字符串一样有索引和切片,不同的是列表中的元素是可以改变的

常用方法:(部分和c++类似)

方法 描述
len(list) 列表元素个数
list.count(obj) 统计某个元素在列表中出现的次数
list.index(obj) 从列表中找出某个值第一个匹配项的索引位置
list.append(obj) 在列表末尾添加新的对象
list.insert(index, obj) 将对象插入列表
list.extend(seq) 在列表末尾一次性追加另一个序列中的多个值
list.pop([index=-1]) 移除列表中的一个元素(默认最后一个),并且返回该元素的值
list.remove(obj) 移除列表中某个值的第一个匹配项
list.sort(reverse=False) 对原列表进行排序
list.reverse() 反向列表中元素

输出小技巧:

1
2
3
4
5
6
7
8
9
10
matrix = [
[1, 2, 3],
[4, 5, 6],
[7, 8, 9]
]
# 直接打印(默认输出)
print(matrix)
# 美化输出(去除逗号和方括号)
for row in matrix:
print(' '.join(str(x) for x in row))

元组Tuple

与列表类似,不同之处在于元组的元素不能修改,写在小括号()里,元素之间用逗号隔开

和列表一样有索引和切片,不同之处在于元组的元素不能修改,且数据类型可以不同

元组中的元素值不允许删除,但可以使用del语句来删除整个元组

常用方法:(方法极少)

方法 描述
len(tuple) 计算元组元素个数
max(tuple)/min(tuple) 返回元组中元素最大/最小值

集合Set

集合(Set)是一个无序、可变、不重复的元素集合,使用大括号{}表示,元素之间用逗号分隔

3.7+中含存储顺序,但逻辑仍无序

可以进行交集、并集、差集等常见的集合操作(&,|,-)

创建一个空集合必须用set()而不是{ },因为{ }是用来创建一个空字典

常用方法:(类比c++ set)

方法 描述
set.add(x) 为集合添加元素
set.remove(x) 从集合移除元素
set.pop() 随机移除元素
set.clear() 移除集合中的所有元素
len(set) 计算集合元素个数

适合用于去重、快速查找、集合关系运算

字典Dictionary

字典是一种映射类型,用{ }标识,是一个无序的键(key):值(value)的集合

3.7+默认保持插入顺序

键(key)必须使用不可变类型,且在同一个字典中,键(key)必须是唯一的

常用方法:

方法 描述
dict.keys() 返回键的视图对象
dict.values() 返回值的视图对象
dict.items() 以列表返回视图对象
dict.update(dict2) 把字典dict2的键/值对更新到dict里
dict.pop(key) 删除字典key所对应的值,返回被删除的值
dict.get(key, default=None) 返回指定键的值,如果不存在返回default设置的值

四大容器对比

特性 列表 list 元组 tuple 集合 set 字典 dict
定义方式 []list() ()tuple() {}set() {key: value}dict()
是否有序 有序 有序 无序 有序
是否可变 可变 不可变 可变 可变
元素是否可重复 允许重复 允许重复 自动去重 键(key)唯一,值(value)可重复
索引访问 支持 支持 不支持 通过 key 访问 value
适用场景 存放有序数据,可修改 固定数据,保护不被修改 去重、集合运算、快速查找 存放键值对,快速查找映射关系

文件操作

1
2
import os
from pathlib import Path

pathlib是对os的一个改进库,建议从os慢慢转变为pathlib

在MacOS和Linux系统下,路径默认使用的都是正斜杠/,在Windows系统下,正反斜杠都可以表示路径分隔符,默认的是反斜杠\,由于反斜杠本身属于转义符,这可能会导致使用反斜杠表示的路径在编码时无法被正确识别。最好就是全部用正斜杆/,避免出问题

目录

功能 os/os.path 写法 pathlib 写法
获取当前工作目录 os.getcwd() Path.cwd()
递归创建单层目录 os.makedirs(path, exist_ok=True) Path(path).mkdir(parents=True, exist_ok=True)
删除空目录 os.rmdir(path) Path(path).rmdir()

os.removedirs(path):递归删除空目录,从最深层的子目录开始删除,直到遇到一个非空目录或者抛出错误为止;

路径

功能 os / os.path 写法 pathlib 写法
拼接路径 os.path.join(path1,path2) Path(path1) / path2
文件名(不含扩展名) (很麻烦) Path(path).stem
文件名(含扩展名) os.path.basename(path) Path(path).name
取扩展名 os.path.splitext(path)[1] Path(path).suffix
取目录名 os.path.dirname(path) Path(path).parent
检查是否存在 os.path.exists(path) Path(path).exists()
删除文件 os.remove(path) Path(path).unlink()

遍历目录

os.listdir列出某个目录下的文件和文件夹名(不递归),只返回名字(不带路径),需要os.path.join拼接绝对/相对路径

Path.iterdir()结合了os.listdir + os.path.join

计算文件夹下所有csv文件的平均值:(两种写法的区别只在循环开始部分)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
from pathlib import Path
import numpy as np

def compute_average(folder):
xs = None # 存放 x 轴
ys_list = [] # 存放所有文件的第二列

for name in os.listdir(folder):
if not name.endswith(".csv"):
continue # 只处理 CSV 文件
path = os.path.join(folder, name)
"""
for path in Path(folder).iterdir():
if path.suffix != ".csv": # 扩展名检查
continue
"""
try:
# 读入两列数据
data = np.loadtxt(path, delimiter=",")
x = data[:, 0]
y = data[:, 1]

if xs is None:
xs = x
else:
# 检查 x 轴是否一致
if not np.allclose(xs, x):
raise ValueError(f"x轴不一致: {path}")

ys_list.append(y)
print(f"Loaded {path}, {len(y)} points")
except Exception as e:
print(f"Skip {path}: {e}")

if not ys_list:
raise RuntimeError("没有找到有效的 CSV 文件")
# 转为numpy
ys = np.array(ys_list)
ys_avg = np.average(ys, axis=0)

return xs, ys_avg

递归遍历

递归式的遍历,遍历指定目录及其子目录中的所有文件和目录

1
os.walk(top, topdown=True, onerror=None, followlinks=False)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
from pathlib import Path
import numpy as np

def compute_average_walk(folder):
xs = None
ys_list = []

for dirpath, _, filenames in os.walk(folder):
for name in filenames:
if not name.endswith(".csv"):
continue
path = os.path.join(dirpath, name) # 重新拼接文件路径
"""
folder = Path(folder) # 需要转一下,不然用不了rglob
for path in folder.rglob("*.csv"):
"""
try:
data = np.loadtxt(path, delimiter=",")
x = data[:, 0]
y = data[:, 1]

if xs is None:
xs = x
else:
if not np.allclose(xs, x):
raise ValueError(f"x轴不一致: {path}")

ys_list.append(y)
print(f"Loaded {path}, {len(y)} points")
except Exception as e:
print(f"Skip {path}: {e}")

if not ys_list:
raise RuntimeError("没有找到有效的 CSV 文件")

ys = np.array(ys_list)
ys_avg = np.average(ys, axis=0)

return xs, ys_avg

numpy

1
import numpy as np

创建和生成

array函数

1
np.array(object, dtype = None, copy = True, order = None, ndmin = 0)
名称 描述
object 数组或嵌套的数列list
dtype 数组元素的数据类型,可选,一般不用自己指定
copy 对象是否需要复制,可选
order 创建数组的样式,C为行方向,F为列方向,A为任意方向(默认)
ndmin 指定生成数组的最小维度

传入复数数组:

1
arr = np.array([[1+2j,2+3j],[3+4j,4+5j]], dtype= np.complex64)

arange函数

1
np.arange(start, stop, step, dtype)
参数 描述
start 起始值,默认为0
stop 终止值(不包含)
step 步长,默认为1

linspace函数

1
np.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)
参数 描述
start 序列的起始值
stop 序列的终止值,如果endpointTrue,该值包含于数列中
num 要生成的等步长的样本数量,默认为50
endpoint 该值为True时,数列中包含stop值,默认True
retstep 该值为True时,生成的数组中会显示间距,反之不显示

logspace函数

1
np.logspace(start, stop, num=50, endpoint=True, base=10.0, dtype=None)
参数 描述
start 起始值:base ** start
stop 终止值:base ** stop,如果endpointTrue,该值包含于数列中
num 要生成的等步长的样本数量,默认为50
endpoint 该值为True 时,数列中中包含stop值,默认是True
base 对数log的底数

ones/zeros函数

创建出来的 array 默认是 float 类型

1
np.ones(shape, dtype = None, order = 'C')
1
np.zeros(shape, dtype = None, order = 'C')
参数 描述
shape 数组形状
order ‘C’行数组,或者’F’用于 FORTRAN 的列数组

random函数

最重要的API,经常用于随机生成训练或测试数据,神经网路初始化等

推荐使用新的方式生成,rng 是个 Generator,可用于生成各种分布

1
rng = np.random.default_rng(42)  # Generator(PCG64) at 0x27C7981F900

rng的size都有(),np.random的不是

np.random rng
0-1 连续均匀分布 np.random.rand(3,4)
np.random.random((3,4))
rng.random((3,4))
指定上下界连续均匀分布 np.random.uniform(-1,1,(2,3)) rng.uniform(-1,1,(2,3))
指定上下界随机整数 np.random.randint(0,10,(2,3)) rng.integers(0,10,(2,3))
标准正态分布(0,1) np.random.randn(2,4) rng.standard_normal((2,4))
正态分布 np.random.normal(0,1,(3,5)) rng.normal(0,1,(3,5))

常用用的就是2个分布:均匀分布和正态(高斯)分布

数组属性

一维数组的秩为1,二维数组的秩为2

属性 说明
arr.ndim 数组的秩(rank),即数组的维度数量或轴的数量
arr.shape 数组的维度,表示数组在每个轴上的大小
arr.size 数组中元素的总个数,等于 np.shape 中各个轴上大小的乘积
arr.dtype 数组中元素的数据类型
arr.itemsize 数组中每个元素的大小,以字节为单位
arr.real 数组中每个元素的实部(如果元素类型为复数)
arr.imag 数组中每个元素的虚部(如果元素类型为复数)

统计函数

最值

1
arr.max(axis=0, keepdims=True) # min同理
参数 描述
axis 0为列,1为行,默认为全部
keepdims 是否保持原有维度,默认False

这个需要特别注意下,很多深度学习模型中都需要保持原有的维度进行后续计算

在统计函数这块axis,keepdims用法都是一样的

另一种写法(少用):

1
2
np.amax(arr)
np.amin(arr)

分位数

1
np.median(arr) # 中位数
1
np.quantile(arr, q, axis) 
1
2
3
4
a = np.array([[10, 7, 4], [3, 2, 1]])
# 75% 的分位数,就是 a 里排序之后的中位数
print (np.percentile(a, 75))
# 6.25 75%*(6-1)=3.75 锁定索引3和4两个数 4+0.75*(7-4)=6.25

平均求和标准差

使用最多的是「平均值」

1
2
3
4
np.average(arr) # 平均值
np.sum(arr) # 和
np.std(arr) # 标准差
np.var(arr) # 方差
1
2
np.cumsum(arr)  # 累加
np.cumprod(arr) # 累乘

元素出现次数

1
np.bincount(arr)

返回不同元素出现的次数

数组操作

这小节里面的API使用非常高频,尤其是扩展1维度的 expand_dims 和去除1维度的 squeeze,在很多神经网络架构常见

修改数组形状

reshape会生成一个新的array,但resize不会

函数 描述
reshape() 原数组形状不变
resize() 会改变原数组形状
ravel() 将多维数组展平
1
2
arr.reshape(2,2,3) 
arr1 = arr.reshape(4,-1) # 可以用-1让编译器自动计算,resize不行

reshape元素数量必须与原array一致

1
2
3
4
# arr1.resize(3,2)       # 报错,数组内元素数量多
arr1copy = np.copy(arr1) # 可以copy再resize来截断
arr1copy.resize(3,2)
arr1.shape # arr1的数组shape不变 (4,3)
1
2
3
4
# 将 refcheck 设为 False ,可以让resize超出部分设为0
# 注意,这种方法不能对切片出来的数据使用,会提示报错
arr2 = np.copy(arr1)
arr2.resize(4,4,refcheck=False) # 多的部分全部为0

修改数组维度

注意:无论是扩展还是缩减,无论是扩展还是缩减,多或少的 shape 都是 1

扩充维度不能跳跃

1
2
3
4
5
6
rng = np.random.default_rng(42)
arr = rng.integers(1,100,(3,4))
expanded = np.expand_dims(arr, axis=(1, 3, 4))
# expanded.shape → (3,1,4,1,1)
# axis 先在1插入,变成(3,1,4) 再在axis=3插入,变成(3,1,4,1) 最后再插
# expanded = np.expand_dims(arr, axis=(1, 3, 8)) # 报错

squeeze时如果指定维度,则该维度shape必须是 1

1
2
3
4
# np.squeeze(expanded, axis=0) # 报错
np.squeeze(expanded, axis=1) # (3,4,1,1)
# 去除所有维度为 1 的
np.squeeze(expanded) # (3,4)

反序

如果对一个字符串或数组进行反序,一般会利用reversed,或者利用list的索引,这就是numpy中array的反序方式

1
2
3
4
arr
# [[ 9 77 65 44]
# [43 86 9 70]
# [20 10 53 97]]
1
2
3
4
arr[::-1] # 默认行反序
# [[20 10 53 97]
# [43 86 9 70]
# [ 9 77 65 44]]
1
2
3
4
5
# 按行反转(上下翻转)
arr[::-1, :]
# [[20 10 53 97]
# [43 86 9 70]
# [ 9 77 65 44]]
1
2
3
4
5
# 按列反转(左右翻转)
arr[:, ::-1]
# [[44 65 77 9]
# [70 9 86 43]
# [97 53 10 20]]
1
2
3
4
5
# 同时行列反转
arr[::-1, ::-1]
# [[97 53 10 20]
# [70 9 86 43]
# [44 65 77 9]]

转置

通俗理解就是把数组放倒,shape反转,行变成列,列成为行

建议二维矩阵用 arr.T(会快很多),超过二维的张量可以用 np.transpose,会更加灵活些

注意:一维数组转置还是自己,不会反转

1
2
3
4
5
6
arr.T
# [[ 9 43 20]
# [77 86 10]
# [65 9 53]
# [44 70 97]]
arr.reshape(1,1,3,4).T.shape # 输出 (4,3,1,1)
1
2
3
# np.transpose可以指定 axes,不指定时和T一样
np.transpose(arr.reshape(1, 1, 3, 4), axes=(0, 2, 1, 3)).shape
# 输出 (1,3,1,4)

拼接

本小节严格来说只有两个API:np.concatenatenp.stack

前者是拼接,后者是堆叠(会增加一个维度),都可以指定维度

vstackhstack虽然看起来是stack,但他俩本质还是concatenate

vstack等价于 axis=0,hstack等价于axis=1,建议只用concatenate

1
2
3
4
5
rng = np.random.default_rng(42)
arr1 = rng.random((2,3))
arr2 = rng.random((2,3))
np.concatenate(arr1, arr2) # 默认沿axis=0(列)连接
np.concatenate((arr1, arr2), axis=1) # 沿行连接
1
2
3
4
5
6
7
# 堆叠,默认根据 axis=0 进行
np.stack((arr1, arr2)) # 增加维度,shape(2,2,3)
# [[[0.77395605 0.43887844 0.85859792]
# [0.69736803 0.09417735 0.97562235]]

# [[0.7611397 0.78606431 0.12811363]
# [0.45038594 0.37079802 0.92676499]]]
1
2
3
4
5
6
7
8
9
# 堆叠,根据 axis=2
np.stack((arr1, arr2), axis=2) # shape为(2,3,2)
# [[[0.77395605 0.7611397 ]
# [0.43887844 0.78606431]
# [0.85859792 0.12811363]]

# [[0.69736803 0.45038594]
# [0.09417735 0.37079802]
# [0.97562235 0.92676499]]]

切片和索引

切片和索引是通过对已有 array 进行操作而得到想要的「部分」元素的行为过程

把处理按维度分开,不处理的维度统一用:替代,也有用...表示的,但是:更多

索引支持负数,即从后往前索引

1
2
arr[:3, 1:3]       # 0-3行,1-3列
arr[1:4:2, 0:3:2] # 也可以start:stop:step来多选

筛选和过滤

主要包括以下内容:

  • 条件筛选
  • 提取(按条件)
  • 抽样(按分布)
  • 最大最小 index(特殊值)

这几个内容都很重要,使用的也非常高频,条件筛选经常用于 Mask 或异常值处理,提取则常用于结果过滤,抽样常用在数据生成(比如负样本抽样),最大最小 index 则常见于机器学习模型预测结果判定中(根据最大概率所在的 index 决定结果属于哪一类)

条件筛选

核心 API 是 np.where,返回输入数组中满足给定条件的元素的索引(元组)

需要注意的是:where分别返回各维度的index,赋值的是「不满足」条件的,类似三元表达式

1
2
np.where(arr>50, arr, -1)  # 将<=50的赋值为-1
# arr>50 ? arr : -1

提取

np.extract() 函数根据某个条件从数组中抽取元素,返回满条件的元素

1
np.extract(condition, arr)

提取和唯一值返回的都是一维向量

1
np.unique(arr) # 提取唯一值,也是一种提取

抽样

1
np.random.choice(a, size=None, replace=True, p=None)
参数 描述
a 如果是整数表示从np.arange(a) 里抽取元素
如果是一维数组直接从数组里抽取
size 抽取结果的形状(默认是返回一个数),也可以是元组
replace True(默认)有放回抽样,元素可重复
False则不放回,元素不重复
p 每个元素被选中的概率,和必须为1,默认均匀分布

如果replace=False,那么 size 不能大于元素总数

1
np.random.choice(1000,50,replace=False)

最值索引

主要是np.argmax/argmin这两个函数

1
2
np.argmax(arr, axis=None)
np.argmin(arr, axis=None)

None(默认)把数组展平成一维,再找最大/最小值的位置,axis=0沿列方向,axis=1沿行方向

np.argsort() 函数返回的是数组值从小到大的索引值

1
np.argsort(arr, axis=-1,kind=None)

axis默认-1,即最后一个轴,如果设置为None则会先展平为一维再排序

kind是排序算法,可选

种类 速度 最坏情况 工作空间 稳定性
'quicksort'(快速排序)(默认) 1 O(n^2) 0
'mergesort'(归并排序) 2 O(n*log(n)) ~n/2
'heapsort'(堆排序) 3 O(n*log(n)) 0

矩阵运算

算术

所有的算术函数均可直接运用于array,+-*/四则运算,**平方以及开方

功能 numpy模块
绝对值 np.abs(x) / np.fabs(x)
四舍五入 np.round(x, n)
取整 np.ceil(x)(向上)/np.floor(x)(向下)
对数 np.log(x) np.log10(x) np.log2(x)
取余 np.mod(x, y) y可以是数组,不同被除数
1
2
np.minimum(arr, num) # 超过num的都置为num
np.maximum(arr, num) # 小于num的都置于num

广播

广播(Broadcast)是numpy对不同形状(shape)的数组进行数值计算的方式

当运算中的2个数组的形状不同时,将自动触发广播机制

广播规则:

  • 如果两个数组的维度不相同,那么小维度数组的形状将会在最左边补1
  • 如果两个数组的形状在任何一个维度上都不匹配,那么数组的形状会沿着维度为1的维度拓展以匹配另外一个数组形状
  • 如果两个数组的形状在任何一个维度上都不匹配并且没有任何一个维度等于1,那么会引发异常
1
2
3
4
5
a = np.array([[1, 2, 3]])   # shape (1,3) → shape(2,3)
b = np.array([[10], [20]]) # shape (2,1) → shape(2,3)
print(a + b)
# [[11 12 13]
# [21 22 23]]

不能广播的情况:

1
2
3
a = np.array([1, 2, 3])     # shape (3,)
b = np.array([1, 2]) # shape (2,)
print(a + b) # 报错,因为维度匹配了但既不相等也不为1 → 无法广播

矩阵

dotmatmul 在高维度时表现不同

1
np.dot(a, b)
1
2
np.matmul(a, b)   # 等价于 a @ b
# 写矩阵运算时推荐用 `@`,代码更直观
情况 np.dot np.matmul(或 @)
1D·1D 向量内积(标量) 向量内积(标量)
2D·2D 矩阵乘法 矩阵乘法
2D·1D 矩阵 × 向量 → 1D 矩阵 × 向量 → 1D
1D·2D 向量 × 矩阵 → 1D 向量 × 矩阵 → 1D
高维数组 只在最后两维做点积 支持批量矩阵乘法

Hadamard乘积(逐元素相乘):直接使用*

常用线性代数方法:

方法 描述
np.linalg.det(arr) 计算行列式
np.linalg.inv(arr) 求逆矩阵(前提行列式>0)
np.linalg.eig(arr) 求特征值和特征向量
np.linalg.solve(A,b) 求方程Ax=b

从文件读取

加载常用数据格式 保存数据
np.loadtxt() np.savetxt()
np.fromstring() np.save(), np.savez(), np.savez_compressed()
1
np.loadtxt(fname, dtype=None, delimiter=",", comments="#", skiprows=1)
参数 描述
fname 文件的路径
delimiter 指定数据在文件中的分隔符,在CSV文件中通常是逗号
skiprows 指示 loadtxt 函数跳过文件前n行
comment 如果行的开头为"#"则跳过
1
np.fromstring(s, dtype, count=-1, sep=',')  # 把字符串转换成一维数组
参数 描述
s 输入的字符串
delimiter 指定数据在文件中的分隔符,在CSV文件中通常是逗号
count 需要读取的元素数量,默认为-1,表示读取字符串中的所有元素
sep 字符串中元素之间的分隔符,在CSV文件中通常是逗号

保存为文件

1
np.savetxt(fname, arr, fmt='%s', delimiter=',', newline='\n', header='', footer='', comments='# ')
参数 描述
fname 文件的路径
arr 要存储的阵列数据
fmt 要存储的数据格式,默认%.18e(科学记数法)
delimiter 加载分隔符,默认是空格
newline 行分隔符,默认换行符
header 开头字符串(存储为csv文件时可以生成标题)
footer 结尾字符串
comments 文中的注释
1
2
3
4
5
row_string = "20131, 10, 67, 20132, 11, 88, 20133, 12, 98, 20134, 8, 100, 20135, 9, 75, 20136, 12, 78"
data = np.fromstring(row_string, sep=",")
data = data.reshape(6, 3)
# 保存数据
np.savetxt("save_data.csv", data, delimiter=",", fmt='%s')

np.save()保存一个array → .npy

np.savez()保存多个array(同np.savez_compressed()) → .npz

1
np.savez(file,kwd1=arr1, kwd2=arr2)
1
2
3
4
5
6
7
train_data = np.array([1, 2, 3])
test_data = np.array([11, 22, 33])
np.savez("save_data_02.npz", train=train_data, test=test_data)
# np.savez_compressed("save_data.npz", train=train_data, test=test_data)
npz_data = np.load("save_data_02.npz")
print("train:", npz_data["train"])
print("test:", npz_data["test"])

matplotlib

1
2
import matplotlib.pyplot as plt
import numpy as np

其用法和matlab极其相似

一个完整的matplotlib图像通常会包括以下四个部分:

  • Figure:用来容纳所有绘图元素,可以包含多个Axes,可以设置画布大小、分辨率等
  • Axes:实际绘制数据的区域,一个Figure可以包含多个Axes,包含坐标轴、标题、标签等
  • Axis:Axes的下属层级,用于处理坐标轴、刻度、标签范围
  • Artist:图表中的元素:线条、点、文字、图例等

matplotlib提供了两种最常用的绘图接口:

  1. 显式创建figure和axes,直接操作Figure和Axes对象,更灵活可控,也被称为OO模式

  2. 依赖pyplot自动创建figure和axes,写法和matlab完全一致

    1
    2
    3
    4
    5
    6
    plt.plot()
    plt.title()
    plt.xlabel()
    plt.ylabel()
    plt.legend([])
    plt.suptitle() # 加大标题
常见属性 设置
颜色color r g b c m y k w
线型linestyle '-', '--', '-.', ''(只画点)
线宽linewidth 一般2
标记marker 'o', 's', 'D', '*', '+', 'x'
透明度alpha 看情况
文本字号、粗细、字体 fontsizefontweightfamily

字体族主要有:

  • serif:Times New Roman,主要论文
  • sans-serif:Arial
  • monospace:Courier New,主要代码
  • cursive(手写体/花体)
  • fantasy(装饰体)

解决中文乱码问题:

1
2
3
4
import matplotlib as mpl
import matplotlib.pyplot as plt
plt.rcParams['font.sans-serif'] = ['SimHei', 'Microsoft YaHei', 'KaiTi', 'SimSun'] # 设置常用中文字体
plt.rcParams['axes.unicode_minus'] = False # 正常显示负号

OO模式

进入OO模式的方法:

1
fig, ax = plt.subplots()  # OO 接口,创建 Figure 和 Axes

Axes(子图)常用操作:

1
2
3
4
ax.set_title("Title")  # 设置标题
ax.set_xlabel("X") # x轴标签
ax.set_ylabel("Y") # y轴标签
ax.grid(True) # 显示网格

Axis(坐标轴)常用操作:

1
2
3
4
ax.set_xlim(low, high)  # x 轴范围
ax.set_ylim(low, high) # y 轴范围
ax.set_xticks([0, 2, 4, 6, 8, 10]) # 指定 x 轴刻度
ax.set_yticks([-1, 0, 1]) # 指定 y 轴刻度

Artist有两种类型:primitivescontainers

primitive是最基础的绘图对象,用于具体显示数据和形状

Artist 类型 描述
Line2D 折线、曲线、散点图
Text 文字、标签、标题
Patch 各种二维形状:矩形、圆、多边形、柱状、扇形(饼图)等
Collection 一组类似元素,如散点图点集合、柱状图柱子集合等

Line2D

线绘制

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
x = np.linspace(0, 2*np.pi, 100)
y1= np.sin(x)
y2 = np.cos(x)
fig, ax = plt.subplots() # fig=画布, ax=子图坐标系
ax.plot(
x, y1,
color='blue', linestyle='-', linewidth=2,
marker='o', markersize=4,
label='sin(x)' # 图例标签
)
ax.plot(
x, y2,
color='red', linestyle='--', linewidth=2,
marker='s', markersize=4,
label='cos(x)'
)
ax.set_title("Sine and Cosine Functions", fontsize=14)
ax.set_xlabel("x", fontsize=12)
ax.set_ylabel("y", fontsize=12)
ax.grid(True)
ax.legend(loc='upper right') # 图例放在右上角
plt.show()

多子图subplot

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
x = np.linspace(0, 10, 100)
y1 = np.sin(x)
y2 = np.cos(x)
# 创建两个子图坐标系
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

ax1.plot(
x, y1,
color='blue', linestyle='-', linewidth=2,
marker='o', markersize=4,
label='sin(x)'
)
ax1.set_title("Sine Function")
ax1.set_xlabel("x")
ax1.set_ylabel("y")
ax1.grid(True)
ax1.legend(loc='upper right')

ax2.plot(
x, y2,
color='red', linestyle='--', linewidth=2,
marker='s', markersize=4,
label='cos(x)'
)
ax2.set_title("Cosine Function")
ax2.set_xlabel("x")
ax2.set_ylabel("y")
ax2.grid(True)
ax2.legend(loc='upper right')

plt.tight_layout() # 自动调整布局,避免标题/标签重叠
plt.show()

散点图scatter

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
rng = np.random.default_rng(42)
x = rng.random(50) * 10 # 50个[0,10)随机数
y = 2 * x + rng.standard_normal(50) * 2 # 带点噪声的直线

fig, ax = plt.subplots()
# 绘制散点图
sc = ax.scatter(
x, y,
s=80, # 点大小,也可以根据值
c=x, # 按x的值映射颜色
cmap="viridis", # 颜色映射
alpha=0.8, # 透明度
edgecolor="black" # 边框颜色
)
# 添加颜色条
fig.colorbar(sc, ax=ax, label="X value")
# 设置标题和坐标轴
ax.set_title("Scatter Plot Example")
ax.set_xlabel("X")
ax.set_ylabel("Y")
plt.show()

Patch

在实际中最常见的矩形图是hist直方图bar条形图

二者的区别:

  • hist:输入一堆数据,Matplotlib 自动分箱并统计

  • bar:自己提供类别 + 数值,Matplotlib 直接画出来

直方图hist

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import matplotlib.pyplot as plt
rng = np.random.RandomState(42)

# 生成数据(正态分布)
data = rng.normal(loc=0, scale=1, size=1000)

fig, ax = plt.subplots()
# 返回每个区间里的频数,分箱的边界以及每个柱子对应的矩形对象
counts, bin_edges, patches = ax.hist(
data,
bins=40, # 分箱数量
# color='skyblue', # 填充色,如果不后续填充
edgecolor='black', # 边框色
alpha=0.7, # 透明度
density=False, # 是否显示密度(False=频数)
label='Data'
)

# 遍历每个柱子,根据区间位置设置颜色
for count, edge_left, edge_right, patch in zip(counts, bin_edges[:-1], bin_edges[1:], patches):
# 区间中心点
center = 0.5 * (edge_left + edge_right)
if center >= 0: # 如果区间中心 >= 0 → 染红
patch.set_facecolor('red')
else: # 如果区间中心 < 0 → 染蓝
patch.set_facecolor('blue')

ax.set_title("Histogram Example")
ax.set_xlabel("Value")
ax.set_ylabel("Frequency")
ax.legend()
plt.show()

条形图bar

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import matplotlib.pyplot as plt
import numpy as np

# 类别与数值
categories = ["A", "B", "C", "D"]
values = [23, 17, 35, 29]

fig, ax = plt.subplots()
bars = ax.bar(
x=np.arange(len(categories)), # 类别位置 (0,1,2,3)
height=values, # 对应的数值
color="orangered", edgecolor="black"
)
# 设置刻度标签
ax.set_xticks(np.arange(len(categories)))
ax.set_xticklabels(categories)
# 添加标题与标签
ax.set_title("Bar Chart Example")
ax.set_xlabel("Category")
ax.set_ylabel("Value")
# 在柱子上标注数值
for bar in bars:
height = bar.get_height()
ax.text(bar.get_x() + bar.get_width()/2, height,
f"{height}", ha='center', va='bottom')
plt.show()

饼图pie

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import matplotlib.pyplot as plt

sizes = [25, 30, 20, 25]
labels = ["A", "B", "C", "D"]
colors = ["skyblue", "orange", "lightgreen", "pink"]
explode = [0, 0.1, 0, 0] # 突出 B

fig, ax = plt.subplots(figsize=(6, 6))
# 画饼图
ax.pie(
x=sizes, # 每个扇区的数值(自动归一化为百分比)
labels=labels, # 扇区对应的类别标签
colors=colors, # 每个扇区的颜色
explode=explode, # 控制扇区“突出”效果
autopct="%1.1f%%", # 显示百分比
startangle=90, # 从90度开始绘制
shadow=True # 添加阴影
)
ax.set_title("Pie Chart Example")
plt.show()