Categories
程式開發

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架


本文要點

  • 開發人員可以使用 Java 和他們喜歡的 IDE 來構建、訓練和部署機器學習(ML)和深度學習(DL)模型
  • DJL 簡化了深度學習(DL)框架的使用,目前支持 Apache MXNet
  • DJL 的開源對於工具包及其用戶來說都是互惠互利的
  • DJL 是引擎無關的,這意味著開發人員只需編寫一次代碼就可以在任何引擎上運行
  • 在嘗試使用 DJL 之前,Java 開發人員應該了解 ML 生命週期和常用的 ML 術語

亞馬遜(Amazon)的 DJL(Deep Java Library )是一個深度學習工具包,使用它可在 Java 中原生地進行機器學習(ML)和深度學習(DL)模型開發,從而簡化深度學習框架的使用。 DJL 是在 2019 年 re:Invent 大會上開源的工具包,它提供了一組高級 API 來訓練、測試和運行在線推理(inference)。 Java 開發人員可以開發自己的模型,也可以在他們的 Java 代碼中使用數據科學家用 Python 開發的預先訓練的模型。

DJL 秉承了 Java 的座右銘,“編寫一次,到處運行(WORA)”,因為它是引擎和深度學習框架無關的。開發人員只需編寫一次就可在任何引擎上運行。 DJL 目前提供了一個 Apache MXNet 的實現,這是一個可以簡化深度神經網絡開發的 ML 引擎。 DJL API 使用 JNA(Java Native Access)來調用相應的 Apache MXNet 操作。 DJL 編排管理基礎設施,基於硬件配置來提供自動的 CPU/GPU 檢測,以確保良好的運行效果。

DJL API 通過抽象常用的功能來開發模型,這使 Java 開發人員能夠利用現有的知識,從而可以輕鬆地過渡到 ML。為了了解 DJL 的實際效果,我們開發一個“鞋”的分類模型作為一個簡單的示例。

機器學習生命週期

我們建立“鞋”分類模型遵循了機器學習的生命週期。 ML 生命週期與傳統的軟件開發生命週期有所不同,它包含六個具體的步驟:

  1. 獲取數據
  2. 清洗並準備數據
  3. 生成模型
  4. 評估模型
  5. 部署模型
  6. 從模型中獲得預測(或推理)

生命週期的最終結果是一個可以查詢並返回答案(或預測)的機器學習模型。

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 1

模型只是數據中趨勢和模式的數學表示。好的數據才是所有 ML 項目的基礎。

在步驟 1 中,從可靠的來源中獲取數據。在步驟 2 中,數據被清洗、轉換並以機器可以學習的格式存儲。清洗和轉換過程通常是機器學習生命週期中最耗時的部分。 DJL 提供了利用翻譯器(translator)來對圖像進行預處理的能力,這能為開發人員簡化清洗和轉換過程。翻譯器可以執行一些圖像任務,比如,可以根據預設參數調整圖像的大小或將圖像從彩色圖轉換為灰度圖。

剛剛過渡向機器學習的開發人員常常會低估清洗和轉換數據所需的時間,因此翻譯器是快速啟動該過程的好方法。步驟 3,在訓練過程中,一個機器學習算法會對數據進行多遍(或多代)處理,不斷研究它們,以試圖學習到不同類型的“鞋”。訓練過程中發現的與“鞋”相關的趨勢和模式會被存儲在模型中。當需要評估模型以確定其在識別“鞋”方面的能力時,第 4 步會作為訓練的一部分;如果發現了錯誤,則予以糾正。在步驟 5 中,將模型部署到生產環境中。模型投入生產後,步驟 6 允許其他系統使用該模型。

通常,可以在代碼中動態地加載模型,或者通過基於 REST 的 HTTPS 端點訪問模型。

數據

“鞋”分類模型是一個多級分類計算機視覺(CV)模型,它使用有監督學習進行訓練,可以將“鞋”分為四類:靴子(boots)、涼鞋(sandals)、鞋子(shoes)或拖鞋(slippers)。有監督學習必須包含已經標記了我們想要預測的目標(或答案)的數據;這就是機器學習的方式。

“鞋”分類模型的數據源是德克薩斯大學奧斯汀分校(The University of Texas at Austin)提供的 UTZappos50k 數據集(dataset),它可免費用於學術和非商業用途。下面這個“鞋子”數據集包含了從 Zappos.com 收集的 50025 張帶標籤的目錄圖像。

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 2

“鞋”數據保存在本地,並使用 DJL 的 ImageFolder 數據集對其進行加載,該數據集可以從本地文件夾中檢索圖像。

// 识别训练数据的位置
String trainingDatasetRoot = "src/test/resources/imagefolder/train";

// 识别验证数据的位置
String validateDatasetRoot = "src/test/resources/imagefolder/validate";

// 创建训练数据 ImageFolder 数据集
ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

//创建验证数据 ImageFolder 数据集
ImageFolder validateDataset = initDataset(validateDatasetRoot);

在本地構造數據時,我並沒有深入到UTZappos50k 數據集所標識的最細粒度的分類等級,比如到腳踝的、膝蓋等高的、到達小腿中部的、過膝的等靴子的最細粒度等級的分類標籤。我的本地數據使用的是最高等級的分類,僅包括靴子、涼鞋、鞋子和拖鞋等四類。

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 3

在 DJL 術語中,數據集只用於保存訓練數據。有些數據集的實現可用於下載數據(基於我們提供的 URL)、提取數據、以及自動地將數據分為訓練集和驗證集。

自動分離是一個特別有用的特性,因為不使用相同的數據來訓練和驗證模型這一點是至關重要的。該模型所使用的訓練數據集用於查找“鞋”數據中的趨勢和模式。驗證數據集通過提供對“鞋”分類模型精度無偏差的估計來檢驗模型的效果。

如果用訓練的數據驗證模型,則會降低我們對模型分類鞋子能力的信心,因為模型是用它已經看到的數據進行測試的。在現實世界中,老師也不會使用和學習指南上完全相同的題目來測試學生,因為這不能衡量一個學生的真實知識或對資料的理解;當然,同樣的概念也適用於機器學習模型。

訓練

現在我們已經將“鞋”數據分為訓練集和驗證集,下面我們將使用神經網絡來訓練(或生成)模型。

public final class Training extends AbstractTraining {

     . . .

     @Override
     protected void train(Arguments arguments) throws IOException {

          // 识别训练数据的位置
          String trainingDatasetRoot = "src/test/resources/imagefolder/train";

          // 识别验证数据的位置
          String validateDatasetRoot = "src/test/resources/imagefolder/validate";

          //创建训练数据 ImageFolder 数据集
          ImageFolder trainingDataset = initDataset(trainingDatasetRoot);

          //创建验证数据 ImageFolder 数据集
          ImageFolder validateDataset = initDataset(validateDatasetRoot);

          . . .
          
          try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
             TrainingConfig config = setupTrainingConfig(loss);

             try (Trainer trainer = model.newTrainer(config)) {
                 trainer.setMetrics(metrics);

                 trainer.setTrainingListener(this);

                 Shape inputShape = new Shape(1, 3, NEW_HEIGHT, NEW_WIDTH);

                 // 根据相应输入的形状初始化训练器
                 trainer.initialize(inputShape);

                 //在数据中查找模式
                 fit(trainer, trainingDataset, validateDataset, "build/logs/training");

                 //设置模型属性
                 model.setProperty("Epoch", String.valueOf(EPOCHS));
                 model.setProperty("Accuracy", String.format("%.2f", getValidationAccuracy()));

                // 训练完成后保存模型,为后面的推理做准备
                //模型保存为 shoeclassifier-0000.params
                model.save(Paths.get(modelParamsPath), modelParamsName);
             }
          }
     }

 }

第一步是通過調用 Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 來獲取模型實例。深度學習是機器學習的一種形式,它使用神經網絡來訓練模型。神經網絡是以人腦中的神經元來進行建模的;神經元是可以將信息(或數據)傳遞給其他細胞的細胞。

ResNet-50 是一種常用於圖像分類的神經網絡,50 表示從初始輸入數據和最終預測之間有 50 個學習層(或神經元)。 getModel() 方法用於創建一個空模型,構造一個 ResNet-50 神經網絡,並將神經網絡設置到該模型中。

public class Models {
   public static ai.djl.Model getModel(int numOfOutput, int height, int width) {
       //创建一个空模型的新实例
       ai.djl.Model model = ai.djl.Model.newInstance();

       //是构建神经网络所需的可组合单元;可以像像乐高积木一样将它们连结在一起,
       //形成一个复杂的网络
       Block resNet50 =
               //构建网络
               new ResNetV1.Builder()
                       .setImageShape(new Shape(3, height, width))
                       .setNumLayers(50)
                       .setOutSize(numOfOutput)
                       .build();

       //将神经网络设置到模型中
       model.setBlock(resNet50);
       return model;
   }
}

下一步是通過調用 model.newTrainer(config) 方法來設置和配置訓練器。通過調用 setupTrainingConfig(loss) 方法來初始化配置對象,該方法通過設置訓練的配置(或超參)來決定如何訓練網絡。

接下來的步驟使我們可以通過設置以下內容來向 Trainer 中添加功能:

  • 使用 trainer.setMetrics(metrics) 來設置 Metrics
  • 使用 trainer.setTrainingListener(this) 來設置訓練監聽器
  • 使用 trainer.initialize(inputShape) 來設置合適的輸入形狀

Metrics 在訓練期間收集並報告關鍵績效指標(KPI),該 KPI 可用於分析和監控訓練的效果和穩定性。下一步是通過調用 fit(trainer, trainingDataset, validateDataset, “build/logs/training”) 方法來啟動訓練過程,該方法將迭代訓練數據並存儲在模型中找到的模式。訓練結束時,使用 model.save(Paths.get(modelParamsPath) 方法將一個表現良好的、經過驗證的模型工件及屬性保存在本地。

訓練過程中報告的度量指標如下所示。注意,隨著每代(epoch)(或每遍(pass))的遞增,模型的精度都會提高;第9代(epoch)的最終訓練精度為90%。

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 4

推理

現在我們已經生成了模型,它可以用於對我們不知道類型(或目標)的新數據執行推理(或預測)。

private Classifications predict() throws IOException, ModelException, TranslateException  {
   //在训练期间保存到模型的位置
   String modelParamsPath = "build/logs";

   //训练时设置的模型名称
   String modelParamsName = "shoeclassifier";

   //需要分类的图像路径
   String imageFilePath = "src/test/resources/slippers.jpg";

   //从路径加载图像文件
   BufferedImage img = BufferedImageUtils.fromFile(Paths.get(imageFilePath));

   //持有每个标签的概率分数
   Classifications predictResult;

   try (Model model = Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH)) {
       //加载模型
       model.load(Paths.get(modelParamsPath), modelParamsName);

       //定义用于预处理和后置处理的翻译器
       Translator translator = new MyTranslator();

       //使用预测器运行推理
       try (Predictor predictor = model.newPredictor(translator)) {
           predictResult = predictor.predict(img);
       }
   }

   return predictResult;
}

在設置了模型和要分類的圖像的必要路徑之後,使用Models.getModel(NUM_OF_OUTPUT, NEW_HEIGHT, NEW_WIDTH) 方法獲取一個空模型實例,並使用model.load(Paths.get(modelParamsPath), modelParamsName) 方法對其進行初始化。它將會加載上一步訓練的模型。

接下來,使用 model.newPredictor(translator) 方法初始化一個帶有指定的 Translator 的 Predictor。在 DJL 術語中,Translator 提供了模型預處理和置後處理的能力。例如,對於 CV 模型,需要將圖像重塑為灰度圖;Translator 是可以做到的。 Predictor 使我們可以利用predictor.predict(img) 方法來對加載的 Model 進行推理,並傳入圖像進行分類。

這個示例展示的是單個的預測,但是 DJL 也支持批量預測。推理存儲在 predictResult 中,predictResult 包含了每個標籤的概率估計。

推理(每張圖片)及其對應的概率得分如下所示。

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 5

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 6

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 7

Deep Java Library (DJL) 簡介:與引擎無關的Java深度學習框架 8

(表格對應的圖片如上所示)

圖像 概率得分
如圖1 (信息) – (                 分類: “0”, 概率: 0.98985 分類: “1”, 概率: 0.00225                 分類: “2”, 概率: 0.00224                 分類: “3”, 概率: 0.00564             ) 分類0 代表靴子,概率得分為 98.98%
圖2 (信息) – (                分類: “0”, 概率: 0.02111                分類: “1”, 概率: 0.76524 分類: “2”, 概率: 0.01159                分類: “3”, 概率: 0.20204           ) 分類1 代表涼鞋,概率得分為 o76.52%
圖3 (信息) – (                分類: “0”, 概率: 0.05523                分類: “1”, 概率: 0.01417                分類: “2”, 概率: 0.87900 分類: “3”, 概率: 0.05158               ) 分類2 代表鞋子,概率得分為 87.90%
圖4 (信息) – (                 分類: “0”, 概率: 0.00003                 分類: “1”, 概率: 0.01133                分類: “2”, 概率: 0.00179                 分類: “3”, 概率: 0.98682 ) 分類3 代表拖鞋,概率得分為of 98.68%.

DJL 提供了與其他 Java 庫一樣的原生 Java 開發體驗和功能。設計這些 API 是為了指導開發人員能夠用最佳實踐來完成深度學習任務。在開始使用 DJL 之前,需要對 ML 生命週期有一個很好的理解。如果您是 ML 初學者,請先閱讀這篇概述或 InfoQ的系列文章《軟件開發人員機器學習入門》。在理解了生命週期和常見的ML術語之後,開發人員就可以快速地掌握 DJL 的 API 了。

亞馬遜已經開源了 DJL,有關該工具包的更多詳細信息可以在 DJL 網站Java 庫 API 規範(Java Library API Specification) 頁面上找到。您也可以回顧下“鞋”分類模型的代碼,以進一步探索該示例。

作者介紹

Kesha Williams 是一位屢獲殊榮的軟件工程師、機器學習實踐者和 A Cloud Guru 的技術講師,擁有24年的經驗。在大學任教期間,她曾培訓並指導了數千名來自美國、歐洲和亞洲的 Java 軟件工程師。她經常帶領創新團隊驗證新興技術,並在全球各地的會議上分享她的經驗教訓。作為TED 的 Spotlight Presentation Academy 的獲得者,她在 TED 舞台上做過機器學習的演講。此外,她在人工智能領域的開創性工作為她贏得了亞馬遜的 Alexa Champion 和 AWS Machine Learning Hero的殊榮。在業餘時間,她通過在線社交專業網絡平台 Colors of STEM 指導女性科技從業者。

原文鏈接:

Getting to Know Deep Java Library (DJL)