谷歌 JAX 深度學習從零開始學

王曉華

  • 谷歌 JAX 深度學習從零開始學-preview-1
  • 谷歌 JAX 深度學習從零開始學-preview-2
  • 谷歌 JAX 深度學習從零開始學-preview-3
谷歌 JAX 深度學習從零開始學-preview-1

買這商品的人也買了...

商品描述

JAX是一個用於高性能數值計算的Python庫,專門為深度學習領域的高性能計算而設計。本書詳解JAX框架深度學習的相關知識,配套示例源碼、PPT課件、數據集和開發環境。 本書共分為13章,內容包括JAX從零開始,一學就會的線性回歸、多層感知機與自動微分器,深度學習的理論基礎,XLA與JAX一般特性,JAX的高級特性,JAX的一些細節,JAX中的捲積,JAX與TensorFlow的比較與交互,遵循JAX函數基本規則下的自定義函數,JAX中的高級包。最後給出3個實戰案例:使用ResNet完成CIFAR100數據集分類,有趣的詞嵌入,生成對抗網絡(GAN)。 本書適合JAX框架初學者、深度學習初學者以及深度學習從業人員,也適合作為高等院校和培訓機構人工智能相關專業的師生教學參考書。

目錄大綱

目    錄

第 1 章  JAX從零開始 1

1.1  JAX來了 1

1.1.1  JAX是什麽 1

1.1.2  為什麽是JAX 2

1.2  JAX的安裝與使用 3

1.2.1  Windows Subsystem for Linux的安裝 3

1.2.2  JAX的安裝和驗證 7

1.2.3  PyCharm的下載與安裝 8

1.2.4  使用PyCharm和JAX 9

1.2.5  JAX的Python代碼小練習:計算SeLU函數 11

1.3  JAX實戰—MNIST手寫體的識別 12

1.3.1  第一步:準備數據集 12

1.3.2  第二步:模型的設計 13

1.3.3  第三步:模型的訓練 13

1.4  本章小結 15

第2章  一學就會的線性回歸、多層感知機與自動微分器 16

2.1  多層感知機 16

2.1.1  全連接層—多層感知機的隱藏層 16

2.1.2  使用JAX實現一個全連接層 17

2.1.3  更多功能的全連接函數 19

2.2  JAX實戰—鳶尾花分類 22

2.2.1  鳶尾花數據準備與分析 23

2.2.2  模型分析—採用線性回歸實戰鳶尾花分類 24

2.2.3  基於JAX的線性回歸模型的編寫 25

2.2.4  多層感知機與神經網絡 27

2.2.5  基於JAX的激活函數、softmax函數與交叉熵函數 29

2.2.6  基於多層感知機的鳶尾花分類實戰 31

2.3  自動微分器 35

2.3.1  什麽是微分器 36

2.3.2  JAX中的自動微分 37

2.4  本章小結 38

第3章  深度學習的理論基礎 39

3.1  BP神經網絡簡介 39

3.2  BP神經網絡兩個基礎算法詳解 42

3.2.1  最小二乘法詳解 43

3.2.2  道士下山的故事—梯度下降算法 44

3.2.3  最小二乘法的梯度下降算法以及JAX實現 46

3.3  反饋神經網絡反向傳播算法介紹 52

3.3.1  深度學習基礎 52

3.3.2  鏈式求導法則 53

3.3.3  反饋神經網絡原理與公式推導 54

3.3.4  反饋神經網絡原理的激活函數 59

3.3.5  反饋神經網絡原理的Python實現 60

3.4  本章小結 64

第4章  XLA與JAX一般特性 65

4.1  JAX與XLA 65

4.1.1  XLA如何運行 65

4.1.2  XLA如何工作 67

4.2  JAX一般特性 67

4.2.1  利用JIT加快程序運行 67

4.2.2  自動微分器—grad函數 68

4.2.3  自動向量化映射—vmap函數 70

4.3  本章小結 71

第5章  JAX的高級特性 72

5.1  JAX與NumPy 72

5.1.1  像NumPy一樣運行的JAX 72

5.1.2  JAX的底層實現lax 74

5.1.3  並行化的JIT機制與不適合使用JIT的情景 75

5.1.4  JIT的參數詳解 77

5.2  JAX程序的編寫規範要求 78

5.2.1  JAX函數必須要為純函數 79

5.2.2  JAX中數組的規範操作 80

5.2.3  JIT中的控制分支 83

5.2.4  JAX中的if、while、for、scan函數 85

5.3  本章小結 89

第6章  JAX的一些細節 90

6.1  JAX中的數值計算 90

6.1.1  JAX中的grad函數使用細節 90

6.1.2  不要編寫帶有副作用的代碼—JAX與NumPy的差異 93

6.1.3  一個簡單的線性回歸方程擬合 94

6.2  JAX中的性能提高 98

6.2.1  JIT的轉換過程 98

6.2.2  JIT無法對非確定參數追蹤 100

6.2.3  理解JAX中的預編譯與緩存 102

6.3  JAX中的函數自動打包器—vmap 102

6.3.1  剝洋蔥—對數據的手工打包 102

6.3.2  剝甘藍—JAX中的自動向量化函數vmap 104

6.3.3  JAX中高階導數的處理 105

6.4  JAX中的結構體保存方法Pytrees 106

6.4.1  Pytrees是什麽 106

6.4.2  常見的pytree函數 107

6.4.3  深度學習模型參數的控制(線性模型) 108

6.4.4  深度學習模型參數的控制(非線性模型) 113

6.4.5  自定義的Pytree節點 113

6.4.6  JAX數值計算的運行機制 115

6.5  本章小結 117

第7章  JAX中的捲積 118

7.1  什麽是捲積 118

7.1.1  捲積運算 119

7.1.2  JAX中的一維捲積與多維捲積的計算 120

7.1.3  JAX.lax中的一般捲積的計算與表示 122

7.2  JAX實戰—基於VGG架構的MNIST數據集分類 124

7.2.1  深度學習Visual Geometry Group(VGG)架構 124

7.2.2  VGG中使用的組件介紹與實現 126

7.2.3  基於VGG6的MNIST數據集分類實戰 129

7.3  本章小結 133

第8章  JAX與TensorFlow的比較與交互 134

8.1  基於TensorFlow的MNIST分類 134

8.2  TensorFlow與JAX的交互 137

8.2.1  基於JAX的TensorFlow Datasets數據集分類實戰 137

8.2.2  TensorFlow Datasets數據集庫簡介 141

8.3  本章小結 145

第9章  遵循JAX函數基本規則下的自定義函數 146

9.1  JAX函數的基本規則 146

9.1.1  使用已有的原語 146

9.1.2  自定義的JVP以及反向VJP 147

9.1.3  進階jax.custom_jvp和jax.custom_vjp函數用法 150

9.2  Jaxpr解釋器的使用 153

9.2.1  Jaxpr tracer 153

9.2.2  自定義的可以被Jaxpr跟蹤的函數 155

9.3  JAX維度名稱的使用 157

9.3.1  JAX的維度名稱 157

9.3.2  自定義JAX中的向量Tensor 158

9.4  本章小結 159

第10章  JAX中的高級包 160

10.1  JAX中的包 160

10.1.1  jax.numpy的使用 161

10.1.2  jax.nn的使用 162

10.2  jax.experimental包和jax.example_libraries的使用 163

10.2.1  jax.experimental.sparse的使用 163

10.2.2  jax.experimental.optimizers模塊的使用 166

10.2.3  jax.experimental.stax的使用 168

10.3  本章小結 168

第11章  JAX實戰——使用ResNet完成CIFAR100數據集分類 169

11.1  ResNet基礎原理與程序設計基礎 169

11.1.1  ResNet誕生的背景 170

11.1.2  使用JAX中實現的部件—不要重復造輪子 173

11.1.3  一些stax模塊中特有的類 175

11.2  ResNet實戰—CIFAR100數據集分類 176

11.2.1  CIFAR100數據集簡介 176

11.2.2  ResNet殘差模塊的實現 179

11.2.3  ResNet網絡的實現 181

11.2.4  使用ResNet對CIFAR100數據集進行分類 182

11.3  本章小結 184

第12章  JAX實戰—有趣的詞嵌入 185

12.1  文本數據處理 185

12.1.1  數據集和數據清洗 185

12.1.2  停用詞的使用 188

12.1.3  詞向量訓練模型word2vec的使用 190

12.1.4  文本主題的提取:基於TF-IDF 193

12.1.5  文本主題的提取:基於TextRank 197

12.2  更多的詞嵌入方法—FastText和預訓練詞向量 200

12.2.1  FastText的原理與基礎算法 201

12.2.2  FastText訓練以及與JAX的協同使用 202

12.2.3  使用其他預訓練參數嵌入矩陣(中文) 204

12.3  針對文本的捲積神經網絡模型—字符捲積 205

12.3.1  字符(非單詞)文本的處理 206

12.3.2  捲積神經網絡文本分類模型的實現—conv1d(一維捲積) 213

12.4  針對文本的捲積神經網絡模型—詞捲積 216

12.4.1  單詞的文本處理 216

12.4.2  捲積神經網絡文本分類模型的實現 218

12.5  使用捲積對文本分類的補充內容 219

12.5.1  中文的文本處理 219

12.5.2  其他細節 222

12.6  本章小結 222

第13章  JAX實戰—生成對抗網絡(GAN) 223

13.1  GAN的工作原理詳解 223

13.1.1  生成器與判別器共同構成了一個GAN 224

13.1.2  GAN是怎麽工作的 225

13.2  GAN的數學原理詳解 225

13.2.1  GAN的損失函數 226

13.2.2  生成器的產生分佈的數學原理—相對熵簡介 226

13.3  JAX實戰—GAN網絡 227

13.3.1  生成對抗網絡GAN的實現 228

13.3.2  GAN的應用前景 232

13.4  本章小結 235

附錄  Windows 11安裝GPU版本的JAX 236