文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

深度学习模型的Android部署方法

2023-09-29 18:42

关注

使用背景:

将python中训练的深度学习模型(图像分类、目标检测、语义分割等)部署到Android中使用。


Step1:下载并集成Pytorch Android库

下载Pytorch Android库。
在Pytorch的官网pytorch.org上找到最新版本的库。下载后,将其解压缩到项目的某个目录下。

配置项目gradle文件
配置项目的gradle文件,向项目添加Pytorch Android库的依赖项。打开项目的build.gradle文件,添加以下代码:

repositories {    // 添加以下两行代码    maven {        url "https://oss.sonatype.org/content/repositories/snapshots/"    }}dependencies {    // 添加以下两行代码    implementation 'org.pytorch:pytorch_android:1.8.0-SNAPSHOT'    implementation 'org.pytorch:pytorch_android_torchvision:1.8.0-SNAPSHOT'}

将库文件添加到项目中
将Pytorch Android库的库文件添加到项目中。可以将其复制到“libs”文件夹中,并在项目的gradle文件中添加以下代码:

android {    sourceSets {        main {            jniLibs.srcDirs = ['libs']        }    }}

配置NDK版本
确保项目使用了支持Pytorch Android库的NDK版本。打开项目的local.properties文件,添加以下代码:

//NDK目录ndk.dir=/path/to/your/ndk 

同步gradle文件
在Android Studio中,点击“Sync Project with Gradle Files”按钮,等待同步完成。
到这就集成了Pytorch Android库。可以在应用程序中使用Pytorch Android库提供的API加载模型文件并进行预测。


Step2:准备Pytorch导出的.pt模型文件

假如我们的深度学习模型输入图片大小尺寸为(640,640,3),并且已经在python中训练好了my_model.pth,那么我们需要将其转换为.pt格式:

import torch# 加载PyTorch模型model = torch.load("my_model.pth")# 将PyTorch模型转换为TorchScript格式traced_script_module = torch.jit.trace(model, torch.randn(1, 3, 640, 640))traced_script_module.save("my_model.pt")

转换Pytorch模型为TorchScript格式时,需要确保使用的所有操作都是TorchScript支持的。否则,在转换模型时可能会出现错误。

Step3:导入Pytorch模型文件

要在Android Studio中创建新项目并将m_model.pt模型文件放入该项目中,包含以下步骤:
1、打开Android Studio,并选择“Create New Project”选项。
2、在“Create New Project”向导中,输入项目名称,选择项目保存位置,并选择“Phone and Tablet”作为您的应用程序目标设备。然后,单击“Next”继续。
3、选择“Empty Activity”模板,并单击“Next”继续。
4、在“Configure Activity”对话框中,输入活动名称并单击“Finish”完成项目创建过程。
5、在项目中创建一个名为“assets”的文件夹。要创建该文件夹,请右键单击项目根目录,选择“New” -> “Folder” -> “Assets Folder”。
6、将m_model.pt模型文件复制到“assets”文件夹中。要将文件复制到“assets”文件夹中,右键单击该文件夹,选择“Show in Explorer”或“Show in Finder”,然后将文件复制到打开的文件夹中。
7、在代码中加载模型文件使用以下代码示例加载模型文件:

AssetManager assetManager = getAssets();String modelPath = "m_model.pt";File modelFile = new File(getCacheDir(), modelPath);try (InputStream inputStream = assetManager.open(modelPath);     FileOutputStream outputStream = new FileOutputStream(modelFile)) {    byte[] buffer = new byte[4 * 1024];    int read;    while ((read = inputStream.read(buffer)) != -1) {        outputStream.write(buffer, 0, read);    }    outputStream.flush();} catch (IOException e) {    e.printStackTrace();}// 加载PyTorch模型Module model = Module.load(modelFile.getAbsolutePath());

在这里需要注意将模型文件保存到应用程序的缓存目录中,而不是将其保存在项目资源中。这是因为在运行时,Android应用程序不能直接读取项目资源,而是需要使用AssetManager类从“assets”文件夹中读取文件。


Step4:模型的调用及使用示例

接下来示例运行模型、获取模型输出和在主线程中更新UI的代码:

import org.pytorch.IValue;import org.pytorch.Module;import org.pytorch.Tensor;import org.pytorch.torchvision.TensorImageUtils;import android.content.res.AssetManager;import android.graphics.Bitmap;import android.graphics.BitmapFactory;import android.os.Bundle;import android.os.Handler;import android.os.Looper;import android.util.Log;import androidx.appcompat.app.AppCompatActivity;import androidx.camera.core.CameraX;import androidx.camera.core.ImageAnalysis;import androidx.camera.core.ImageProxy;import androidx.camera.core.Preview;import androidx.camera.lifecycle.ProcessCameraProvider;import androidx.camera.view.PreviewView;import androidx.core.content.ContextCompat;import androidx.lifecycle.LifecycleOwner;import java.io.IOException;import java.io.InputStream;import java.util.concurrent.ExecutorService;import java.util.concurrent.Executors;public class MainActivity extends AppCompatActivity {    private static final String MODEL_PATH = "m_model.pt";    private static final int INPUT_SIZE = 224;    private Module mModule;    private ExecutorService mExecutorService;    private Handler mHandler;    @Override    protected void onCreate(Bundle savedInstanceState) {        super.onCreate(savedInstanceState);        setContentView(R.layout.activity_main);        // 加载PyTorch模型和创建执行线程池        loadModel();        // 创建主线程处理程序        mHandler = new Handler(Looper.getMainLooper());        // 启动相机        startCamera();    }    private void loadModel() {        // 加载PyTorch模型        try {            AssetManager assetManager = getAssets();            InputStream inputStream = assetManager.open(MODEL_PATH);            mModule = Module.load(inputStream);        } catch (IOException e) {            Log.e("MainActivity", "Error reading model file: " + e.getMessage());            finish();        }        // 创建执行线程池        mExecutorService = Executors.newSingleThreadExecutor();    }    private void startCamera() {        // 创建PreviewView        PreviewView previewView = findViewById(R.id.preview_view);        // 配置相机生命周期所有者        LifecycleOwner lifecycleOwner = this;        // 配置相机预览        Preview preview = new Preview.Builder().build();        preview.setSurfaceProvider(previewView.getSurfaceProvider());        // 配置图像分析        ImageAnalysis imageAnalysis =                new ImageAnalysis.Builder()                        .setBackpressureStrategy(ImageAnalysis.STRATEGY_KEEP_ONLY_LATEST)                        .build();        // 设置图像分析的处理程序        imageAnalysis.setAnalyzer(                mExecutorService,                new ImageAnalysis.Analyzer() {                    @Override                    public void analyze(ImageProxy image, int rotationDegrees) {                        // 将ImageProxy转换为Bitmap                        Bitmap bitmap =    Bitmap.createScaledBitmap(            image.getImage(),            INPUT_SIZE,            INPUT_SIZE,            false);                        // 将Bitmap转换为Tensor                        Tensor tensor =    TensorImageUtils.bitmapToFloat32Tensor(            bitmap,            TensorImageUtils.TORCHVISION_NORM_MEAN_RGB,            TensorImageUtils.TORCHVISION_NORM_STD_RGB);                        // 创建输入列表                        final IValue[] inputs = {IValue.from(tensor)};                        // 运行模型                        Tensor outputTensor = mModule.forward(inputs).toTensor();                        // 获取模型输出                        float[] scores = outputTensor.getDataAsFloatArray();                        // 查找最高分数                        float maxScore = -Float.MAX_VALUE;                        int maxScoreIndex = -1;                        for (int i = 0; i < scores.length; i++) {if (scores[i] > maxScore) {    maxScore = scores[i];    maxScoreIndex = i;}                        }                        // 获取分类标签                        String[] labels = getLabels();                        String predictedLabel = labels[maxScoreIndex];                        // 更新UI                        updateUI(predictedLabel);                    }                });        // 绑定相机生命周期所有者        CameraX.bindToLifecycle(lifecycleOwner, preview, imageAnalysis);    }    private String[] getLabels() {        // 在此处替换为标签文件        return new String[]{                "tench",                "goldfish",                "great white shark",                "tiger shark",                // ...        };    }    private void updateUI(String predictedLabel) {        mHandler.post(                new Runnable() {                    @Override                    public void run() {                        // 更新UI                        // 例如,将预测标签写入TextView                        // TextView textView = findViewById(R.id.text_view);                        // textView.setText(predictedLabel);                    }                });    }    @Override    protected void onDestroy() {        super.onDestroy();        // 释放模型和执行线程池        mModule.destroy();        mExecutorService.shutdown();    }}

当模型预测输入图像时,它将返回一个整数,该整数表示模型预测的图像类型的索引。可以使用该索引来查找对应的标签并更新UI。例如:

// 查找最高分数float maxScore = -Float.MAX_VALUE;int maxScoreIndex = -1;for (int i = 0; i < scores.length; i++) {    if (scores[i] > maxScore) {        maxScore = scores[i];        maxScoreIndex = i;    }}// 获取分类标签String[] labels = getLabels();String predictedLabel = labels[maxScoreIndex];// 更新UIupdateUI(predictedLabel);

Step5:调试程序

编译和运行应用程序,并在Android Studio调试上测试图像识别功能。

来源地址:https://blog.csdn.net/qq_45193872/article/details/130210414

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-移动开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯