Diffusion models là nhóm mô hình sinh dựa trên ý tưởng quá trình ngẫu nhiên khuyếch tán của vật lí. Thả giọt mực vào một cốc nước, các phân tử sẽ lan tỏa, cốc nước ban đầu sớm nhộm màu đồng đều bởi các phân tử có mặt khắp nơi. Nghệ thuật là đây, nếu cầm trong tay một cốc màu đã tan. Liệu có cách nào quay ngược thời gian lại phân bố trạng thái ban đầu của một trạng thái giọt mực ban đầu.
Nếu phân bố xác suất có thể ví von như bản đồ niềm tin đặc tả bởi các con số, liệu ta có thể xây dựng một chuổi phác thảo các bản đồ được sắp xếp trình tự một cách hợp lí để quay ngược thời gian từ thời điểm kết thúc đến thời điểm bắt đầu?
Một trong những hướng tiếp cận mô hình sinh (generative model) là định nghĩa một quá trình ngẫu nhiên mà bạn có thể hiểu quá trình thuận nghịch của nó. Dữ liệu thực tế (ví dụ: ảnh số) có phân bố rất phức tạp, khó lòng có thể mô hình hóa lấy mẫu trực tiếp, thay vì vậy ta có thể chuyển nhiều bài toán về dạng đơn giản hơn, bằng cách chuyển phân bố phức tạp thành phân bố đơn giản và lấy mẫu trên quá trình dịch ngược.
1. Lối mẫu ngẫu nhiên, đi dể khó về
Mới hôm nao, kí ức ngày tựu trường trong veo, vậy mà giờ đây chỉ còn là bức tranh nhạt nhòa. Mỗi ngày lại quên đi một chút ít, bức tranh ngày ấy mờ dần theo tháng năm. Gắng nhớ mãi cũng không thể nào nhớ lại được “một mạch” thời điểm ấy bắt đầu như thế nào.
Nếu là một kẻ du hành thời gian, liệu ta có thể xâu chuỗi các mắc xích thời gian. Xây dựng quá trình tiến khiến tất cả sẽ tan biến chỉ còn dĩ vãng. Xây dựng quá trình lùi cứ đi từng bước một và tưởng tượng ngày ấy thế nào. Hợp lí hóa cả cả hai quá trình, liệu ta có thể quay ngược lại nơi bắt đầu?
Lấy ý tưởng tương tự, ta sẽ định nghĩa một quá trình tăng nhiễu dần đến khi phân bố gốc trở thành phân bố chuẩn tắc nhiều chiều. Tại mỗi bước đi lùi, ta sẽ ước tính tham số phân bố dịch ngược với T bước nhảy thời gian dài đăng đẳng. Và hợp lí hóa quá trình trên bởi phân bố hội của quá trình tiến & lùi qua ELBO.
Mô phỏng thử các lối mẫu ngẫu nhiên (sample path của stochastic process) của quá trình khuyếch tán, bạn sẽ thấy có vô vàn kết quả để đến một vị trí đích. Nếu chọn một vị trí đích xt (điểm dữ liệu x tại thời điểm t ) một câu hỏi được đặt ra, nếu tất cả các điểm đều có thể định lượng cơ hội xảy ra bằng một con số, cơ hội sẽ được phân bố như thế nào, nơi nào nhiều hơn, nơi nào ít hơn?
Một cách đơn giản nhất để đi ngược thời gian là sẽ di chuyển ngược hàm mật độ khi biết vị trí đang đứng hiện tại ở t, hay nói một cách khác cố gắng mô hình hóa phân bố p(xt−1∣xt) được tham số hóa.
Nhưng nếu mô phỏng đủ nhiều, bạn sẽ thấy phân bố này tương đối phức tạp, và cũng chẳng có thông tin nào về nó cả… đấy là cái khó khi… đó là nếu, đó là mơ, quá trình nghe vẫn rất kiêu, vẫn rất khó để làm điều như vậy. Nhưng hãy cùng nhau làm mọi thứ đơn giản hơn nào!
Giải thích hình vẽ:
Mỗi khung ảnh chuyển động là một lối mẫu ngẫu nhiên có thể xảy ra (đùa tí cho vui: hãy tưởng tượng bạn là Dr Strange và thấy được tất cả các kết quả của các thế giới khác nhau). Nếu để ý, các điểm được đánh dấu có thể di chuyển bất kì đâu.
Nếu như ta ở điểm màu cam xt, vì các điểm có thể di chuyển bất kì đâu, để đi lùi quá khứ ta có thể xây dựng bản đồ niềm tin a.k.a phân bố xác suất dịch ngược pθ(xt−1∣xt) , hmmm nhưng mà xây dựng như thế nào nhỉ? Một câu hỏi thú vị là nơi một hành trình bắt đầu… đi thôi nào…
2. Denoise Diffusion Probabilistic Models:
Sỡ dĩ gọi là Denoise Diffusion Probabilistic Models là vì:
Quá trình lùi (Reverse Process) tại đây ta chỉ có điểm dữ liệu trông giống như nhiễu :D ví von quá trình đi dịch chuyển lịch sử giống như việc khữ nhiễu bức ảnh cũ kĩ đến khi dữ liệu trở về nguyên bản.
Trong bài viết này dùng kí hiệu N(.) để kí hiệu cho phân bố chuẩn nhiều chiều.
Việc lựa chọn phân bố xác suất trong mô hình phụ thuộc nhiều tiêu chí: phù hợp mô tả quá trình xác suất, dể tính toán. Phân bố chuẩn có nhiều tính chất đẹp & tiện lợi cho việc tính toán nên được sử dụng nhiều trong bài toán này.
2.1 Quá trình tiến (Forward Process)
Quá trình tiến tại mỗi bước nhảy thời gian đơn giản là dần dần cộng thêm nhiễu vào. Hay nói cách khác:
xt=xt−1+ξt−1
Nếu chọn nhiễu được sinh ra từ phân bố chuẩn ξt−1∼N(μξt−1,Σξt−1) . Tinh gọn công thức lại thành một phân bố duy nhất, có thể viết:
Ký hiệu q là hàm mật độ của quá trình tiến, ta có:
q(xt∣xt−1)=N(xt∣xt−1+μξt−1,Σξt−1)
Trong bài báo Denoising Diffusion Probabilistic Models hàm mật độ được chọn:
q(xt∣xt−1)=N(xt∣1−βtxt−1,βtI)
Với:
t là bước nhảy thời gian từ 1 đến T.
βt là phương sai lịch trình (schedule variance).
βt và T được gán cứng. Không thay đổi trong quá trình huấn luyện mô hình.
Để hiểu lựa chọn trên, chúng ta sẽ cùng nhau phân tích quá trình này.
Dẫu chúng ta học về quá trình ngẫu nhiên, ấy vậy hầu như không có sự chọn lựa ngẫu nhiên trong lúc xây dựng mô hình :D mọi thứ đã được toan tính kỹ lưỡng cho một tương lai xa. Công thức trên nếu được thiết lập một cách đúng đắn, tại cuối quá trình ta sẽ trở về phân bố chuẩn tắc nhiều chiều.
Khi đó phân bộ hội (joint distribution) của một lối mẫu ngẫu nhiên cùng nhau xuất hiện x1,x2,…xt được tính như sau:
Phaˆn boˆˊ hội quaˊ trıˋnh tieˆˊnq(x1:T∣x0)=t=1∏Tq(xt∣xt−1)
Một tính chất đáng chú ý của quá trình tiến, vì chọn nhiễu có phân bố chuẩn, ta có thể tùy tiện lấy mẫu tại một bước nhảy bất kì với:
xt∼N(αtˉxt−1,(1−αtˉ)I)
Với:
αt=1−βt
αtˉ=∏s=1tαs
Bởi vì chúng ta muốn cuối quá trình là phân bố chuẩn tắc nhiều chiều:
q(xT∣x0)≈N(0,I)
Hay nói cách khác
αˉT≈0
Việc chọn phương sai lịch trình β và T bước nhảy là không khó để thỏa mãn điều kiện trên.
2.2 Quá trình lùi (Reverse Process)
Trở về câu hỏi ban đầu, liệu ta có thể mô hình hóa p(xt−1∣xt). Như thường lệ… cái hay cái đẹp sẽ không bao giờ đến một cách dể dàng…
Nếu mô phỏng đủ nhiều, bạn sẽ thấy phân bố này quá phức tạp để có thể mô hình hóa. Mà nếu, nếu có mô hình hóa được, thì cũng rất khó tối ưu, thật đấy thử đi.
Một lời giải hay không phải là một lời giải chính xác tuyệt đối, mà một lời giải đủ khéo léo và tinh vi…
Gọi p(xt−1∣xt) là hàm mật độ lùi có phân bố chuẩn nhiều chiều chưa biết tham số phân bố:
p(xt−1∣xt)=N(xt−1∣?,?)
Tuy vậy, tham số của phân bố vẫn chưa rõ. Nếu để ý, ta có thể thấy phân bố của mỗi điểm phụ thuộc vào vị trí xt và bước nhảy thời gian t.
…hmmm, với sức mạnh tính toán của các mô hình học sâu, chúng ta hoàn toàn có thể xấp xĩ các tham số này! Dường như đây là mảnh ghép quan trọng còn thiếu trong bài toán này.
Phân bố chuẩn nhiều chiều có 2 tham số cần tìm là trung bình và ma trận hiệp phương sai. Tuy vậy nếu càng ít tham số phải tính thì bài toán sẽ càng dể hơn nhiều lần. [Ho et al., 2020] cho thấy chúng ta có thể chọn ma trận hiệp phương sai là βt.
Gọi μθ(xt,t) là trung bình được xấp xĩ của mạng học sâu mà ta sẽ xây dựng. Ta có:
pθ(xt−1∣xt)=N(xt−1∣μθ(xt,t),βt)
Lúc này phân bố hội qúa trình lùi có thể viết lại:
Phaˆn boˆˊ hội quaˊ trıˋnh luˋipθ(x0:T)=pθ(xT)t=1∏Tpθ(xt−1∣xt)
2.3 Huấn luyện mô hình
Trước hết, khi vào bài toán này ta chỉ có điểm dữ liệu tại thời gian ban đầu x0 (biết trước), mọi trình tự ngẫu nhiên phía sau x1:T được sinh ra từ phân bố quá trình tiến q(.∣x0). Như anh chàng thám tử trong thế giới ngẫu nhiên, ta thử dùng bất đẳng thức ELBO (Evidence lower bound lnp(a)≥Eb∼q(.∣a)[lnq(b∣a)p(a,b)] ), ta có:
Vì sao không tối ưu trực tiếp loglikelihood của pθ(x0) ? Đơn giản là vì nó khó lấy trong bài toán này.
Vì ta muốn áp dụng cho toàn bộ tập dữ liệu, lấy kì vọng vế trái và phải ta có (lưu ý: kỳ vọng của kỳ vọng viết đơn giản thành kỳ vọng), ta sẽ tối ưu vế phải bài toán này, vế phải càng nhỏ thì vế trái càng nhỏ:
E[−lnpθ(x0)]≤−Eq[lnq(x1:T∣x0))pθ(x0:T)]
Liệu bạn có cảm nhận được điều gì không hửm thám tử SherLog Holmes? Vế phải tử thức là phân bố hội quá trình lùi, mẫu thức là phân bố hội quá trình tiến? Liệu công thức ý ẩn ý gì chăng?
Nếu tiếp tục đơn giản hóa công thức, bạn sẽ nhận ra vế phải là tổng của KL-divergence của từng bước nhảy thời gian quá trình tiến và quá trình lùi, và đằng sau đó là công thức tường minh đẹp sau lớp mặt nạ ngụy trang. Nào hãy cùng vạch trần… công thức…
Sỡ dĩ ở đây chúng ta có thể phân tích ra thành KL Divergence bởi vì bản chất KL Divergence là kỳ vọng lograrít giữa hai phân bố.
Nhận thấy:
LT: không có tham số tối ưu. Có thể lược giản.
L0: bình thường công thức này trong ELBO chúng ta có thể xấp xĩ bằng cách lấy mẫu ngẫu nhiên Monte Carlo. Tuy vậy nhóm tác giả đã gộp vào công thức tính của Lt−1
Lt−1: có công thức đẹp : ) vì đó là KL-divergence giữa hai phân bố chuẩn.
Do KL Divergence đang lấy là của hai phân bố chuẩn nhiều chiều, ta có công thức tường minh:
Lt−1=Eq[2σt21∥μt~(xt,x0)−μθ(xt,t)∥2]+C
Với C là hằng số. Do đó ta có thể loại khỏi công thức tối ưu.
Nhưng than ôi, 30 chưa phải là tết, cái kết chưa đến. Nếu là một người làm tối ưu hóa, công thức trên vẫn khó tối ưu. Mặc dù bạn hoàn toàn có thể lấy mẫu xt một cách đơn giản, tuy vậy thách thức lớn nhất là làm sao cho mô hình khả vi tối ưu tham số được với vectơ gradient.
Ta sẽ dùng mẹo tái tham số hóa (reparametrization trick, bạn đọc quan tâm có thể đọc thêm bài viết thú vị [6]). Vì phân bố chuẩn thuộc họ Location–scale family ta có thể viết:
xt(x0,ϵ)=αˉtx0+1−αˉtϵ với ϵ∼N(0,I)
Như đã nói ở quá trình tiến, bởi vì điều thú vị trong bài toán này cho phép ta lấy mẫu tùy tiện, kỳ vọng của quá trình tiến có thể viết μt~(xt,x0)=1−αˉtαˉt−1βtx0+1−αˉtαt(1−αˉt−1)xt .
Bởi vì xt đã có, hơn nữa nếu nhìn vào công thức trên thì mạng μθ phải dự đoán αt1(xt(x0,ϵ)−1−αˉtβtϵ) do đó ta có thể “cố tình” chọn tham số hóa như sau:
μθ(xt,x0)=αt1(xt−1−αˉtβtϵθ(xt,t))
Với ϵθ(xt,t) là hàm xấp xĩ xt khi biết xt. Khi đó việc lấy mẫu xt−1∼pθ(xt−1∣xt) được tính toán:
xt−1=αt1(xt−1−αˉtβtϵθ(xt,t))+σtz với z∼N(0,I)
Với công thức trên t=1 ứng với L0 (trong bài viết này không chứng minh, bạn có thể đọc thêm tài liệu [1] phần 3.3), với t>1 tương ứng Lt−1 không trọng số.
2.4 Mã giả huấn luyện mô hình
Trong mã nguồn bên dưới bạn có thể tìm đến phần comment tương ứng trong source code bằng cách tìm:
[Pseudocode T3]
[Pseudocode T4]
[Pseudocode T5]
[Pseudocode T7]
T1 | FUNCTION Train-Diffusion-Probabilistic-Model: T2 | Repeat { T3 | x0∼q(x0) T4 | t∼Uniform({1,...,T}) T5 | ϵ∼N(0,I) T6 | Lấy vectơ gradient và cập nhật tham số T7 | ∇θϵ−ϵθ(αˉtx0+1−αˉtϵ,t)2 T8 | } Until (hội tụ) T9 | END
2.4 Mã giả lấy mẫu
Trong mã nguồn bên dưới bạn có thể tìm đến phần comment tương ứng trong source code bằng cách tìm:
[Pseudocode S2]
[Pseudocode S3]
[Pseudocode S4]
[Pseudocode S5]
S1 | FUNCTION Sample-Diffusion-Probabilistic-Model: S2 | xT∼N(0,I) S3 | For t=T,...,1 do { S4 | z∼N(0,I) if t>1, else z=0 S5 | xt−1=αt1(xt−1−αˉtβtϵθ(xt,t))+σtz S6 | } End For S7 | RETURN x0
3. Cài đặt DDPM
Phần mã nguồn đưới đây được cài đặt với PyTorch cho dữ liệu giả lập trong không gian 2 chiều để dể hình dung về mặt lý thuyết. Bạn đọc quan tâm với dữ liệu ảnh có thể đọc thêm ở phần 4.
import tqdm
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.datasets import make_moons, make_swiss_roll
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KernelDensity
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
plt.style.use("bmh")
plt.rcParams.update({'font.size':18})
device ='cpu'
dtype = torch.float32
np_dtype = np.float32
n_steps =100
n_epochs =300
Phần lớn nhóm mô hình thuộc họ DDPM dùng phục vụ cho các bài toán ảnh tạo sinh. Stable Diffusion là một trong những mô hình thành công khi xây dựng Diffusion Process kết hợp với kiến trúc mạng UNet.
Một số bài hướng dẫn trên mạng cho tập dữ liệu ảnh, bạn đọc quan tâm có thể tìm hiểu qua đường dẫn sau:
Phần giải thích trong bài viết của ThetaLog vẫn chưa đề cập đến kỹ thuật Conditional Diffusion Models, trong nhiều bài toán bạn sẽ muốn kiểm soát được ngữ cảnh lấy mẫu, việc sinh ra mẫu từ một mô tả ngữ cảnh nếu bạn đọc quan tâm có thể tìm hiểu thêm (ứng dụng trong bài toán tạo sinh từ mô tả văn bản).