本文以 OjectDetection 例子为子(市面上一个很火的360智能跟拍云台)展开说明,TensorFlow Lite可以与 Android 8.1 中发布的神经网络 API 完美配合,即便在没有硬件加速时也能调用 CPU 处理,确保模型在不同设备上的运行。
整个工程大致的过程就是从控件 textureView 中以指定的长宽读取一个 Bitma p出来(也就是摄像头的实时画面),然后交给 classifier 的 classifyFrame 进行处理,返回一个结果,这个结果就是物体检测的结果,然后显示在手机屏幕上。
一、环境的搭建
我们可以使用 Android Studio 创建一个 Android 项目,一路默认就可以了,并不需要 C++ 的支持,因为是拿人家训练好的模型直接来用,不用去训练模型,即用到的 TensorFlow Lite 是 Java 代码的,开发起来非常方便。但需要特别的功能,就需要使用 TensorFlow 去训练模型了。
1.1 依赖
创建完成之后,在 app 目录下的 build.gradle 配置文件加上以下配置信息,如在 dependencies 下加上包的引用(每次运行都下载依赖):
//依赖库
implementation('org.tensorflow:tensorflow-lite:0.0.0-nightly') { changing = true }
对于 Android 有一个地方需要注意,必须在 app 模块的 build.gradle 中添加如下的语句,否则无法加载模型。
//set no compress models
aaptOptions {
noCompress "tflite"
}
1.2 模型文件配置
在 main 目录下创建 assets 文件夹,这个文件夹主要是存放 tflite 模型和 label 名称文件。
TensorFlow Lite 提供了 C ++ 和 Java 两种类型的 API。无论哪种 API 都需要加载模型和运行模型。而 TensorFlow Lite 的 Java API 使用了 Interpreter 类(解释器)来完成加载模型和运行模型的任务。
二、原始数据的获取
手机端的深度学习输入参数有视觉和听觉,即图像和声音,对于图像而言, Camera 是图像采集的唯一工具。因此需要了解 Camera2 的几个比较重要的类:
- CameraManager: 管理手机上的所有摄像头设备,它的作用主要是获取摄像头列表和打开指定的摄像头;
- CameraDevice: 具体的摄像头设备,它有一系列参数(预览尺寸、拍照尺寸等),可以通过 CameraManager 的 getCameraCharacteristics() 方法获取。它的作用主要是创建 CameraCaptureSession 和 CaptureRequest;
- CameraCaptureSession: 相机捕获会话,用于处理拍照和预览的工作(很重要);
- CaptureRequest: 捕获请求,定义输出缓冲区以及显示界面(TextureView 或 SurfaceView)等。
数据获取的过程:通过Camera 获取图片,然后使用对图片进行压缩,之后把图片转换成 ByteBuffer 格式的数据。
2.1 定义 AutoFitTextureView 作为预览界面
在布局文件中加入 AutoFitTextureView 控件,然后实现其监听事件
textureView = (AutoFitTextureView) view.findViewById(R.id.texture);
然后我们可以在OnResume()方法中设置监听 SurefaceTexture 的事件
textureView.setSurfaceTextureListener(surfaceTextureListener);
当SurefaceTexture准备好后会回调SurfaceTextureListener 的onSurfaceTextureAvailable()方法
TextureView.SurfaceTextureListener textureListener = new TextureView.SurfaceTextureListener() {
@Override
public void onSurfaceTextureAvailable(SurfaceTexture surface, int width, int height) {
//当SurefaceTexture可用的时候,设置相机参数并打开相机
setUpCameraOutputs(width, height);
openCamera();
}
};
2.2 设置相机参数
为了更好地预览,我们根据TextureView的尺寸设置预览尺寸,Camera2中使用CameraManager来管理摄像头
private void setUpCameraOutputs(int width, int height) {
final Activity activity = getActivity();
final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
try {
final CameraCharacteristics characteristics = manager.getCameraCharacteristics(cameraId);
//获取StreamConfigurationMap,它是管理摄像头支持的所有输出格式和尺寸
StreamConfigurationMap map = characteristics.get(CameraCharacteristics.SCALER_STREAM_CONFIGURATION_MAP);
//根据TextureView的尺寸设置预览尺寸
mPreviewSize = getOptimalSize(map.getOutputSizes(SurfaceTexture.class), width, height);
} catch (CameraAccessException e) {
e.printStackTrace();
}
}
2.3 开启相机
Camera2 中打开相机也需要通过 CameraManager 类操作。
private void openCamera() {
final Activity activity = getActivity();
final CameraManager manager = (CameraManager) activity.getSystemService(Context.CAMERA_SERVICE);
try {
if (!cameraOpenCloseLock.tryAcquire(2500, TimeUnit.MILLISECONDS)) {
throw new RuntimeException("Time out waiting to lock camera opening.");
}
manager.openCamera(cameraId, stateCallback, backgroundHandler);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
} catch (final InterruptedException e) {
throw new RuntimeException("Interrupted while trying to lock camera opening.", e);
}
}
实现StateCallback 接口,当相机打开后会回调onOpened方法,在这个方法里面开启预览
private final CameraDevice.StateCallback stateCallback =
new CameraDevice.StateCallback() {
@Override
public void onOpened(final CameraDevice cd) {
// This method is called when the camera is opened. We start camera preview here.
cameraOpenCloseLock.release();
cameraDevice = cd;
//开启预览
createCameraPreviewSession();
}
@Override
public void onDisconnected(final CameraDevice cd) {
cameraOpenCloseLock.release();
cd.close();
cameraDevice = null;
}
......
};
2.4 开启相机预览
我们使用 TextureView 显示相机预览数据,Camera2 的预览和拍照数据都是使用 CameraCaptureSession 会话来请求的。
private void createCameraPreviewSession() {
try {
final SurfaceTexture texture = textureView.getSurfaceTexture();
assert texture != null;
//设置TextureView的缓冲区大小
texture.setDefaultBufferSize(previewSize.getWidth(), previewSize.getHeight());
//获取Surface显示预览数据
final Surface surface = new Surface(texture);
//创建CaptureRequestBuilder,TEMPLATE_PREVIEW比表示预览请求
previewRequestBuilder = cameraDevice.createCaptureRequest(CameraDevice.TEMPLATE_PREVIEW);
previewRequestBuilder.addTarget(surface);
// 使用ImageReader间接实现
previewReader =
ImageReader.newInstance(
previewSize.getWidth(), previewSize.getHeight(), ImageFormat.YUV_420_888, 2);
previewReader.setOnImageAvailableListener(imageListener, backgroundHandler);
previewRequestBuilder.addTarget(previewReader.getSurface());
//创建相机捕获会话,第一个参数是捕获数据的输出Surface列表,第二个参数是CameraCaptureSession的状态回调接口,当它创建好后会回调onConfigured方法,第三个参数用来确定Callback在哪个线程执行,为null的话就在当前线程执行
cameraDevice.createCaptureSession(
Arrays.asList(surface, previewReader.getSurface()),
new CameraCaptureSession.StateCallback() {
@Override
public void onConfigured(final CameraCaptureSession cameraCaptureSession) {
// The camera is already closed
if (null == cameraDevice) {
return;
}
//创建捕获请求
captureSession = cameraCaptureSession;
try {
// Auto focus should be continuous for camera preview.
previewRequestBuilder.set(
CaptureRequest.CONTROL_AF_MODE,
CaptureRequest.CONTROL_AF_MODE_CONTINUOUS_PICTURE);
// Flash is automatically enabled when necessary.
previewRequestBuilder.set(
CaptureRequest.CONTROL_AE_MODE, CaptureRequest.CONTROL_AE_MODE_ON_AUTO_FLASH);
// Finally, we start displaying the camera preview.
previewRequest = previewRequestBuilder.build();
//设置反复捕获数据的请求,这样预览界面就会一直有数据显示
captureSession.setRepeatingRequest(
previewRequest, captureCallback, backgroundHandler);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
}
}
@Override
public void onConfigureFailed(final CameraCaptureSession cameraCaptureSession) {
showToast("Failed");
}
},
null);
} catch (final CameraAccessException e) {
LOGGER.e(e, "Exception!");
}
}
2.5 拍照
Camera2 拍照也是通过 ImageReader 来实现的。
首先先做些准备工作,设置拍照参数,如方向、尺寸等
/** Conversion from screen rotation to JPEG orientation. */
private static final SparseIntArray ORIENTATIONS = new SparseIntArray();
static {
ORIENTATIONS.append(Surface.ROTATION_0, 90);
ORIENTATIONS.append(Surface.ROTATION_90, 0);
ORIENTATIONS.append(Surface.ROTATION_180, 270);
ORIENTATIONS.append(Surface.ROTATION_270, 180);
}
/** Callback for Camera2 API */
@Override
public void onImageAvailable(final ImageReader reader) {
// We need wait until we have some size from onPreviewSizeChosen
if (previewWidth == 0 || previewHeight == 0) {
return;
}
if (rgbBytes == null) {
rgbBytes = new int[previewWidth * previewHeight];
}
try {
final Image image = reader.acquireLatestImage();
if (image == null) {
return;
}
if (isProcessingFrame) {
image.close();
return;
}
isProcessingFrame = true;
final Plane[] planes = image.getPlanes();
fillBytes(planes, yuvBytes);
yRowStride = planes[0].getRowStride();
final int uvRowStride = planes[1].getRowStride();
final int uvPixelStride = planes[1].getPixelStride();
imageConverter =
new Runnable() {
@Override
public void run() {
ImageUtils.convertYUV420ToARGB8888(
yuvBytes[0],
yuvBytes[1],
yuvBytes[2],
previewWidth,
previewHeight,
yRowStride,
uvRowStride,
uvPixelStride,
rgbBytes);
}
};
postInferenceCallback =
new Runnable() {
@Override
public void run() {
image.close();
isProcessingFrame = false;
}
};
processImage();
} catch (final Exception e) {
LOGGER.e(e, "Exception!");
Trace.endSection();
return;
}
}
三、TensorFlow Lite 处理
3.1 加载训练模型
loadModelFile()方法是把模型文件读取成MappedByteBuffer,之后给Interpreter类初始化模型
// load infer model
private void loadModel(String model) {
try {
tflite = new Interpreter(loadModelFile(model));
Log.d(TAG, model + " model load success");
//tflite.setNumThreads(4);
} catch (IOException e) {
Log.d(TAG, model + " model load fail");
e.printStackTrace();
}
}
/**
* Memory-map the model file in Assets.
*/
private MappedByteBuffer loadModelFile(String model) throws IOException {
AssetFileDescriptor fileDescriptor = getApplicationContext().getAssets().openFd(model + ".tflite");
FileInputStream inputStream = new FileInputStream(fileDescriptor.getFileDescriptor());
FileChannel fileChannel = inputStream.getChannel();
long startOffset = fileDescriptor.getStartOffset();
long declaredLength = fileDescriptor.getDeclaredLength();
return fileChannel.map(FileChannel.MapMode.READ_ONLY, startOffset, declaredLength);
}
得到一个对象tflite,之后就是使用这个对象来预测图像,同时可以使用这个对象设置一些参数。
我们先分析一下再 assets 目录下面怎么加载的?说白了就是新建一个 Interpreter 对象,就是加载模型。上面的方法都过时了,我们可以找到 Interpreter类,里面你会看到如下的方法
//第一个参数传tflite文件,第二个参数传一个Interpreter静态内部类对象
public Interpreter(@NonNull File modelFile, Interpreter.Options options) {
this.wrapper = new NativeInterpreterWrapper(modelFile.getAbsolutePath(), options);
}
//所以,我们自己项目里面加载模型,用如下方式即可
//file:///android_asset/labelmap.txt, detect.tflite
Interpreter.Options options = new Interpreter.Options();
options.setNumThreads(4);
tflite = new Interpreter(new File(""), options);
3.2 读取文件种分类标签对应的名称
读取文件种分类标签对应的名称,这个文件 labelmap.txt 跟模型一样存放在 assets 目录下,这个文件比较长,里面有对用的文件。
private List<String> resultLabel = new ArrayList<>();
try {
AssetManager assetManager = getApplicationContext().getAssets();
BufferedReader reader = new BufferedReader(new InputStreamReader(assetManager.open("labelmap.txt")));
String readLine = null;
while ((readLine = reader.readLine()) != null) {
resultLabel.add(readLine);
}
reader.close();
} catch (Exception e) {
Log.e("labelCache", "error " + e);
}
3.3 进行检测
执行run方法
tflite.run(in, out);
或
Object[] inputArray = {imgData};
Map<Integer, Object> outputMap = new HashMap<>();
outputMap.put(0, outputLocations);
outputMap.put(1, outputClasses);
outputMap.put(2, outputScores);
outputMap.put(3, numDetections);
Trace.endSection();
// Run the inference call.
tfLite.runForMultipleInputsOutputs(inputArray, outputMap);
显示检测结果
// Show the best detections.
// after scaling them back to the input size.
// You need to use the number of detections from the output and not the NUM_DETECTONS variable declared on top
// because on some models, they don't always output the same total number of detections
// For example, your model's NUM_DETECTIONS = 20, but sometimes it only outputs 16 predictions
// If you don't use the output's numDetections, you'll get nonsensical data
int numDetectionsOutput = Math.min(NUM_DETECTIONS, (int) numDetections[0]); // cast from float to integer, use min for safety
final ArrayList<Recognition> recognitions = new ArrayList<>(numDetectionsOutput);
for (int i = 0; i < numDetectionsOutput; ++i) {
final RectF detection =
new RectF(
outputLocations[0][i][1] * inputSize,
outputLocations[0][i][0] * inputSize,
outputLocations[0][i][3] * inputSize,
outputLocations[0][i][2] * inputSize);
// SSD Mobilenet V1 Model assumes class 0 is background class
// in label file and class labels start from 1 to number_of_classes+1,
// while outputClasses correspond to class index from 0 to number_of_classes
int labelOffset = 1;
recognitions.add(
new Recognition(
"" + i,
labels.get((int) outputClasses[0][i] + labelOffset),
outputScores[0][i], //最大概率或得分最高
detection));
}
- ObjectDetection : Github