Skip to content

自動微分 | DIY 實現自己的 PyTorch

在機器學習問題中,模型訓練的核心是梯度下降:

  1. 計算損失函數(Loss Function)
  2. 計算損失對參數的導數(Gradient)
  3. 根據梯度更新參數
  4. 重復直到收斂

我們需要找到一個高效的梯度計算方法。

幾種梯度計算方法

梯度計算方法包括:數值微分、符號微分、自動微分

1. 數值微分

數值微分的基本思想,是使用一個非常小的 hh 去近似計算。

dfdxf(x+h)f(x)h\frac{df}{dx} \approx \frac{f(x+h) - f(x)}{h}

但是,這種計算方法存在舍入誤差,而且計算復雜度高。

如果我們有 nn 個參數,需要執行 O(n)O(n) 次函數計算。在參數量達百萬級的神經網絡中,這種計算成本是不可接受的。

2. 符號微分

符號微分的思路是通過鏈式法則得到一個完整的解析表達式。在 Python 中,我們可以使用 sympy 去實現。

但是,一旦函數變得復雜,表達式長度增長極快,同一個子表達式在求導過程中會被多次重復計算。這種方法都理論上很優雅,在工程上卻不可行。

自動微分

自動微分將復雜函數拆解成一系列簡單的「基本運算」,並在計算過程中同步或反向計算導數。

自動微分通過記錄這些基礎運算的執行軌跡(即構建「計算圖」),並在計算過程中系統化地應用微積分中的鏈式法則##,從而極其精確且高效地計算出復雜函數對各個變量的導數。

要實現自動微分,首先要將復雜的算式轉化為「計算圖」。在這張有向無環圖(DAG)中:葉子節點代表輸入變量或模型參數。內部節點代表基本運算(如加法、乘法、sinsin 等)。邊代表數據流向。

我們可以在計算圖上進行反向傳播,這一過程分為兩個階段:

  1. 前向階段: 從輸入到輸出正常計算,但會保存所有的中間變量(計算圖)。
  2. 反向階段: 從最終的標量輸出(如 Loss 損失值)開始,從輸出向輸入反向遍歷計算圖。它通過鏈式法則,將輸出節點對當前節點的導數(稱為伴隨值 Adjoint,用 vˉi\bar{v}_i 表示)一步步往回傳遞。

工程實現上,我們可以使用一個簡單的拓撲排序,按照拓撲順序,從輸出節點反向計算梯度,完成梯度下降。

工程實現

筆者自己完成了一自動微分的簡單實現,包括了一些常用的神經網絡 Loss Function 和 Optimizer,在 GitHub 上開源:https://github.com/aeilot/simplegrad

我實現的是類似 PyTorch 的動態圖。

下面簡單介紹一些核心模塊。

Tensor

1
2
3
4
5
6
7
8
9
10
11
12
class Tensor:
def __init__(self, data, requires_grad: bool = False):
if not isinstance(data, np.ndarray):
data = np.array(data, dtype=np.float32)

self.data = data
self.requires_grad = requires_grad
self.grad = None # 梯度
self.parents = [] # 父節點
self.grad_fn = None # 創造了這個 `Tensor` 的數學操作

# 以下省略

Tensor 不僅僅是簡單的張量,還存儲了梯度、父節點等信息。通過運算符重載,我們可以隱式地構建計算圖,實現反向傳播。

Ops / Function

對於數學操作,我們需要定義前向運算和反向運算。

1
2
3
4
5
6
7
8
9
10

class Function:
def __init__(self):
pass

def forward(self):
raise NotImplementedError

def backward(self, grad_output):
raise NotImplementedError

定義了接口,我們就可以實現各種各樣的算子。

以矩陣乘法為例。

眾所周知,矩陣乘法求導有如下公式:

Y=ABY = AB

dA=LA=LYBTdA = \frac{\partial L}{\partial A} = \frac{\partial L}{\partial Y} B^T

dB=LB=ATLYdB = \frac{\partial L}{\partial B} = A^T \frac{\partial L}{\partial Y}

我們很容易寫出代碼:

1
2
3
4
5
6
7
8
9
10
class MatMul(Function):
def __init__(self, a, b):
super().__init__()
self.a = a
self.b = b

def backward(self, grad):
grad_a = grad @ self.b.data.T
grad_b = self.a.data.T @ grad
return [grad_a, grad_b]

然後進行運算符重載:

1
2
3
4
5
6
7
8
9
10
def __matmul__(self, other):
other = ensure_tensor(other) # 防止類型異常
out = Tensor(
self.data @ other.data,
requires_grad=self.requires_grad or other.requires_grad,
)
if out.requires_grad and GradMode.enabled:
out.grad_fn = ops.MatMul(self, other)
out.parents = [self, other]
return out

運算符重載的時候,我們儲存了父節點,用戶不需要自己手動建圖,我們自動構建了計算圖。

用這種方法,我實現了加法、減法、Hadamard 積、對數、Softmax、求和、平均數等運算,可以在 GitHub 倉庫 查看(不要白嫖 給個 star)

GradMode

細心的讀者可能註意到了,在前面 MatMul 的重載代碼中,有一個判斷條件 if out.requires_grad and GradMode.enabled:

這對應了 PyTorch 中的 torch.no_grad()。在模型訓練好後進行推理(Inference)或評估時,我們不需要更新參數,也就不需要計算梯度。如果此時框架還在後臺默默「建圖」(保存父節點和 grad_fn),會白白浪費大量內存。

我們可以用一個簡單的上下文管理器(Context Manager)來實現這個開關:

1
2
3
4
5
6
7
8
9
10
11
class GradMode:
enabled = True

@contextmanager
def no_grad():
prev = GradMode.enabled
GradMode.enabled = False
try:
yield
finally:
GradMode.enabled = prev

使用時非常優雅:

1
2
3
with no_grad():
# 這裏的運算不會構建計算圖,節省內存
y_pred = model(x)

backward

反向傳播的任務是從輸出節點出發,一路沿著圖的邊,把梯度精準地傳回給所有的輸入節點。

我們基於 DFS + 拓撲排序實現,保證在計算某個節點的梯度之前,依賴於它的所有「下遊」節點的梯度都已經被完全計算出來。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
def backward(output):
# 1. 拓撲排序:確保梯度的計算順序正確
def toposort(tensor):
visited = set()
order = []

def dfs(node):
node_id = id(node)
if node_id in visited:
return
visited.add(node_id)
for pa in node.parents:
dfs(pa)
order.append(node)

dfs(tensor)
order.reverse() # 翻轉得到從輸出到輸入的遍歷順序
return order

order = toposort(output)

# 2. 梯度初始化:最終輸出節點(如 Loss)的梯度設為 1
output.grad = np.ones_like(output.data)

# 3. 鏈式法則與梯度傳播
for node in order:
# 如果是葉子節點(如用戶直接輸入的變量),則跳過
if node.grad_fn is None:
continue

# 調用計算節點對應的局部反向函數
grads = node.grad_fn.backward(node.grad)

# 遍歷當前節點的所有父節點,將梯度回傳
for p, g in zip(node.parents, grads):
if p.grad is None:
p.grad = unbroadcast(g, p.shape)
else:
# 關鍵點:梯度累加
p.grad += unbroadcast(g, p.shape)

當一個變量在計算圖中被多次使用時(例如 y=xxy = x * x),它的梯度必須是各個分支回傳梯度的總和。

這裏註意一點細節,有的時候 numpy 運算會出現 broadcasting,導致父節點和子節點 shape 不符。

在 NumPy 中,如果我們把一個形狀為 (3, 1) 的張量與一個形狀為 (1, 3) 的張量相加,NumPy 會極其聰明地將它們自動擴展(Broadcast)成 (3, 3) 並計算出結果。然而,在反向傳播時,上遊傳下來的梯度形狀是 (3, 3),但我們原本的節點形狀是 (3, 1),直接將梯度傳回去會導致形狀不匹配而報錯。

因此,在把梯度賦值給父節點之前,我們必須進行一次反廣播(Unbroadcasting),把多出來的維度通過求和操作「擠壓」回原本的形狀:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
def unbroadcast(grad, target_shape):
# 如果形狀已經完美匹配,直接返回
if grad.shape == target_shape:
return grad

# 情況一:處理由於廣播而新增的前置維度(例如標量與矩陣運算)
ndims_added = grad.ndim - len(target_shape)
for _ in range(ndims_added):
grad = grad.sum(axis=0)

# 情況二:處理被擴展的維度(原本大小為 1 的維度被擴展成了 N)
for i, dim in enumerate(target_shape):
if dim == 1:
# 沿著被擴展的維度進行求和降維,並保持維度結構
grad = grad.sum(axis=i, keepdims=True)

return grad

unbroadcast 函數的邏輯非常清晰,它分兩步解決了形狀還原的問題:

  1. 首先,如果正向計算時由於維度不對齊導致新增了維度,我們將這些多余的維度全部按軸求和並抹除。

  2. 其次,逐一對比目標形狀。如果發現原本該維度的大小為 1(意味著它被 NumPy 復製擴展了),我們就將傳回來的梯度沿著這個維度進行聚合求和,並使用 keepdims=True 維持其 (..., 1, ...) 的骨架。

優化器與梯度清零

我實現了一個帶動量的隨機梯度下降優化器:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import numpy as np

class SGD:
def __init__(self, parameters, lr=0.01, momentum=0.0):
self.parameters = list(parameters)
self.lr = lr
self.momentum = momentum

# 為每個參數初始化一個全零的速度矩陣
self.velocities = {id(p): np.zeros_like(p.data) for p in self.parameters}

def step(self):
for p in self.parameters:
if p.grad is None:
continue

pid = id(p)
grad = p.grad

if self.momentum != 0.0:
# 核心動量公式:v = momentum * v + grad
self.velocities[pid] = self.momentum * self.velocities[pid] + grad
update_dir = self.velocities[pid]
else:
update_dir = grad

# 更新參數
p.data -= self.lr * update_dir

def zero_grad(self):
for p in self.parameters:
# 這裏要求我們的 Tensor 類中實現一個簡單的 zero_grad 方法
# def zero_grad(self): self.grad = None
p.zero_grad()

在我們的 backward 實現中,有一行代碼是 p.grad += unbroadcast(...)。我們使用的是梯度累加(+=),而不是直接賦值(=)。這意味著,如果不手動清空梯度,第二輪叠代的梯度就會疊加上第一輪的梯度,導致參數更新完全錯誤!因此,每次叠代前必須調用 optimizer.zero_grad()

其他

我在 SimpleGrad 中還實現了如 Adam 優化器等其他模塊,大家可以自行查看。

About this Post

This post is written by Louis C Deng, licensed under CC BY-NC 4.0.

#數學 #PyTorch #機器學習