Investigating why the input to the model is incorrect, resulting in an error on initialization.

This commit is contained in:
Lucas Oskorep
2022-06-21 23:41:42 -04:00
parent dc4bf39c14
commit ebfbfb503d
10 changed files with 340 additions and 17 deletions
+38
View File
@@ -0,0 +1,38 @@
@echo off
rem Downloads the prebuilt TFLite C libraries for Android from the
rem tflite_flutter_plugin GitHub releases and installs each one under
rem android\app\src\main\jniLibs\<abi>\libtensorflowlite_c.so.
rem Usage: script [-d]   (-d selects the GPU-delegate builds for ARM ABIs)
setlocal enableextensions
rem Run from the script's own directory so the relative jniLibs path resolves.
cd %~dp0
rem Available versions: 2.5, 2.4.1 (release tag is "tf_" + version).
set TF_VERSION=2.5
set URL=https://github.com/am15h/tflite_flutter_plugin/releases/download/
set TAG=tf_%TF_VERSION%
set ANDROID_DIR=android\app\src\main\jniLibs\
set ANDROID_LIB=libtensorflowlite_c.so
rem Release asset names: "_delegate" variants bundle the GPU delegate.
set ARM_DELEGATE=libtensorflowlite_c_arm_delegate.so
set ARM_64_DELEGATE=libtensorflowlite_c_arm64_delegate.so
set ARM=libtensorflowlite_c_arm.so
set ARM_64=libtensorflowlite_c_arm64.so
set X86=libtensorflowlite_c_x86_delegate.so
set X86_64=libtensorflowlite_c_x86_64_delegate.so
rem d=1 selects the delegate (GPU) variants for the ARM ABIs.
SET /A d = 0
:GETOPT
rem NOTE(review): only %1 is inspected and there is no SHIFT/GOTO loop back
rem to :GETOPT, so -d is recognized only as the first argument — confirm
rem whether multi-flag parsing was intended.
if /I "%1"=="-d" SET /A d = 1
SETLOCAL
if %d%==1 CALL :Download %ARM_DELEGATE% armeabi-v7a
if %d%==1 CALL :Download %ARM_64_DELEGATE% arm64-v8a
if %d%==0 CALL :Download %ARM% armeabi-v7a
if %d%==0 CALL :Download %ARM_64% arm64-v8a
rem The x86 builds always use the delegate assets.
CALL :Download %X86% x86
CALL :Download %X86_64% x86_64
EXIT /B %ERRORLEVEL%
rem :Download <release-asset> <abi-dir>
rem Fetches the asset and installs it as the canonical library name under
rem the ABI's jniLibs directory.
:Download
curl -L -o %~1 %URL%%TAG%/%~1
mkdir %ANDROID_DIR%%~2\
rem NOTE(review): move /-Y prompts before overwriting an existing library,
rem which stalls unattended runs — confirm whether /Y was intended.
move /-Y %~1 %ANDROID_DIR%%~2\%ANDROID_LIB%
EXIT /B 0
+52
View File
@@ -0,0 +1,52 @@
#!/usr/bin/env bash
# Downloads the prebuilt TFLite C libraries for Android from the
# tflite_flutter_plugin GitHub releases and installs each one under
# android/app/src/main/jniLibs/<abi>/libtensorflowlite_c.so.
#
# Usage: script [-d]   (-d selects the GPU-delegate builds for the ARM ABIs)

# Run from the script's own directory so the relative jniLibs path resolves;
# abort if that fails instead of installing into the wrong place (SC2164).
cd "$(dirname "$(readlink -f "$0")")" || exit 1

# Available versions
# 2.5, 2.4.1
TF_VERSION=2.5

URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/"
TAG="tf_$TF_VERSION"

ANDROID_DIR="android/app/src/main/jniLibs/"
ANDROID_LIB="libtensorflowlite_c.so"

# Release asset names: "_delegate" variants bundle the GPU delegate.
ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so"
ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so"
ARM="libtensorflowlite_c_arm.so"
ARM_64="libtensorflowlite_c_arm64.so"
X86="libtensorflowlite_c_x86_delegate.so"
X86_64="libtensorflowlite_c_x86_64_delegate.so"

delegate=0

while getopts "d" OPTION
do
    case $OPTION in
        d) delegate=1;;
    esac
done

# download <release-asset> <abi-dir>
# Fetches the asset and installs it as the canonical library name under
# the ABI's jniLibs directory.
download () {
    wget "${URL}${TAG}/$1" -O "$1"
    mkdir -p "${ANDROID_DIR}$2/"
    # Quoted to prevent word splitting/globbing of the filename (SC2086).
    mv "$1" "${ANDROID_DIR}$2/${ANDROID_LIB}"
}

if [ ${delegate} -eq 1 ]
then
    download ${ARM_DELEGATE} "armeabi-v7a"
    download ${ARM_64_DELEGATE} "arm64-v8a"
else
    download ${ARM} "armeabi-v7a"
    download ${ARM_64} "arm64-v8a"
fi

# The x86 builds always use the delegate assets.
download ${X86} "x86"
download ${X86_64} "x86_64"
View File
+1 -1
View File
@@ -1,5 +1,5 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/ui/home.dart';
import 'package:tensordex_mobile/ui/tensordex_home.dart';
import 'package:tensordex_mobile/utils/logger.dart';
Future<void> main() async {
+160
View File
@@ -0,0 +1,160 @@
import 'dart:math';
import 'dart:ui';
import 'package:collection/collection.dart';
import 'package:image/image.dart' as image_lib;
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
import '../utils/logger.dart';
import '../utils/recognition.dart';
import '../utils/stats.dart';
/// Image classifier backed by a TFLite model.
///
/// Loads the interpreter and labels asynchronously from assets (or adopts
/// pre-built ones, e.g. for testing), pre-processes incoming frames to the
/// model's input size, and returns the top-scoring label plus timing stats.
class Classifier {
  static const String MODEL_FILE_NAME = "detect.tflite";
  static const String LABEL_FILE_NAME = "labelmap.txt";

  /// Input size of image (height = width = 224)
  static const int INPUT_SIZE = 224;

  /// Result score threshold
  static const double THRESHOLD = 0.5;

  /// [ImageProcessor] used to pre-process the image (built lazily on first use)
  ImageProcessor? imageProcessor;

  /// Instance of Interpreter
  late Interpreter _interpreter;

  // Output buffer sized from the model's output tensor; filled by [predict].
  late TensorBuffer _outputBuffer;

  // Post-processing for the raw output. NormalizeOp(0, 1) is an identity
  // transform ((x - 0) / 1), kept as a hook for real normalization later.
  late var _probabilityProcessor;

  /// Labels file loaded as list
  late List<String> _labels;

  /// Number of results to show
  static const int NUM_RESULTS = 10;

  // Set once loadModel completes successfully. predict() must check this
  // instead of `_interpreter == null`: _interpreter is a non-nullable `late`
  // field, so reading it before initialization throws LateInitializationError
  // rather than comparing equal to null.
  bool _modelLoaded = false;

  Classifier({
    Interpreter? interpreter,
    List<String>? labels,
  }) {
    // NOTE(review): both loaders are async and not awaited here, so predict()
    // can be called before initialization finishes; _modelLoaded guards the
    // model side. Consider an async factory if callers need to await setup.
    loadModel(interpreter: interpreter);
    loadLabels(labels: labels);
  }

  /// Loads interpreter from asset (or adopts [interpreter] if provided) and
  /// allocates the output buffer from the model's output tensor metadata.
  Future<void> loadModel({Interpreter? interpreter}) async {
    try {
      _interpreter = interpreter ??
          await Interpreter.fromAsset(
            MODEL_FILE_NAME,
            options: InterpreterOptions()..threads = 4,
          );

      var outputTensor = _interpreter.getOutputTensor(0);
      _outputBuffer =
          TensorBuffer.createFixedSize(outputTensor.shape, outputTensor.type);
      _probabilityProcessor =
          TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
      _modelLoaded = true;
    } catch (e) {
      logger.e("Error while creating interpreter: ", e);
    }
  }

  /// Loads labels from assets unless [labels] is supplied.
  Future<void> loadLabels({List<String>? labels}) async {
    try {
      // TODO(review): LABEL_FILE_NAME ("labelmap.txt") is unused; the asset
      // path here is hard-coded to "assets/labels.txt" — confirm which one
      // the app actually bundles.
      _labels = labels ?? await FileUtil.loadLabels("assets/labels.txt");
    } catch (e) {
      logger.e("Error while loading labels: $e");
    }
  }

  /// Pre-processes [inputImage]: resizes to INPUT_SIZE x INPUT_SIZE with
  /// bilinear sampling and normalizes pixels via (x - 127.5) / 127.5,
  /// mapping [0, 255] to [-1, 1].
  TensorImage? getProcessedImage(TensorImage inputImage) {
    imageProcessor ??= ImageProcessorBuilder()
        .add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR))
        .add(NormalizeOp(127.5, 127.5))
        .build();
    return imageProcessor?.process(inputImage);
  }

  /// Runs classification on [image].
  ///
  /// Returns null if the model is not yet loaded or pre-processing failed;
  /// otherwise a map with "recognitions" (top [Recognition]) and "stats".
  Map<String, dynamic>? predict(image_lib.Image image) {
    var predictStartTime = DateTime.now().millisecondsSinceEpoch;

    if (!_modelLoaded) {
      logger.e("Interpreter not initialized");
      return null;
    }
    logger.i(labels);

    var preProcessStart = DateTime.now().millisecondsSinceEpoch;
    // Create TensorImage from image and pre-process it.
    var procImage = getProcessedImage(TensorImage.fromImage(image));
    var preProcessElapsedTime =
        DateTime.now().millisecondsSinceEpoch - preProcessStart;

    if (procImage == null) {
      return null;
    }

    // Debug logging while investigating incorrect model input.
    logger.i("Sending image to ML");
    logger.i(procImage.buffer.asFloat32List());
    logger.i(procImage.width);
    logger.i(procImage.height);
    logger.i(procImage.tensorBuffer.shape);
    logger.i(procImage.tensorBuffer.isDynamic);

    // Run inference, timing the interpreter call itself. (Previously the
    // elapsed time was computed *before* run(), so it always measured ~0.)
    var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
    _interpreter.run(procImage.buffer, _outputBuffer.getBuffer());
    var inferenceTimeElapsed =
        DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;

    // Attach labels to the (identity-normalized) output scores and pick the
    // highest-probability entry.
    Map<String, double> labeledProb = TensorLabel.fromList(
            labels, _probabilityProcessor.process(_outputBuffer))
        .getMapWithFloatValue();
    final pred = getTopProbability(labeledProb);
    Recognition rec = Recognition(1, pred.key, pred.value);

    var predictElapsedTime =
        DateTime.now().millisecondsSinceEpoch - predictStartTime;
    return {
      "recognitions": rec,
      // NOTE(review): all four Stats arguments receive the total time;
      // preProcessElapsedTime and inferenceTimeElapsed are computed above —
      // confirm Stats' parameter order and pass them through.
      "stats": Stats(predictElapsedTime, predictElapsedTime,
          predictElapsedTime, predictElapsedTime),
    };
  }

  /// Gets the interpreter instance
  Interpreter get interpreter => _interpreter;

  /// Gets the loaded labels
  List<String> get labels => _labels;
}
/// Returns the entry with the highest probability in [labeledProb].
///
/// Single O(n) pass instead of building a full PriorityQueue (O(n log n))
/// just to read its first element. On ties the first entry encountered
/// (map insertion order) wins. Throws [StateError] on an empty map, as the
/// queue-based version did.
MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
  return labeledProb.entries
      .reduce((best, e) => e.value > best.value ? e : best);
}
/// Comparator ordering [MapEntry]s by *descending* value: returns a
/// negative result when [e1] outranks [e2], zero on equal values, and a
/// positive result otherwise.
int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
  final a = e1.value;
  final b = e2.value;
  if (a == b) return 0;
  return a > b ? -1 : 1;
}
+36 -1
View File
@@ -2,6 +2,9 @@ import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/tflite/classifier.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tensordex_mobile/utils/image_utils.dart';
import '../utils/logger.dart';
import '../utils/recognition.dart';
@@ -30,10 +33,13 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
/// Controller
late CameraController cameraController;
Interpreter? interp;
/// true when inference is ongoing
bool predicting = false;
late Classifier classy;
// /// Instance of [Classifier]
// Classifier classifier;
//
@@ -56,9 +62,28 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
// Camera initialization
initializeCamera();
// final gpuDelegateV2 = GpuDelegateV2(
// options: GpuDelegateOptionsV2(
// isPrecisionLossAllowed: false,
// inferencePreference: TfLiteGpuInferenceUsage.fastSingleAnswer,
// inferencePriority1: TfLiteGpuInferencePriority.minLatency,
// inferencePriority2: TfLiteGpuInferencePriority.auto,
// inferencePriority3: TfLiteGpuInferencePriority.auto,
// ));
logger.e("CREATING THE INTERPRETOR");
var interpreterOptions = InterpreterOptions();//..addDelegate(gpuDelegateV2);
interp = await Interpreter.fromAsset('efficientnet_v2s.tflite',
options: interpreterOptions);
logger.e("CREATING THE INTERPRETOR");
classy = Classifier(interpreter: interp);
logger.i(interp?.getOutputTensors());
// Create an instance of classifier to load model and labels
// classifier = Classifier();
// Initially predicting = false
predicting = false;
}
@@ -94,7 +119,7 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
@override
Widget build(BuildContext context) {
// Return empty container while the camera is not initialized
if (!cameraController.value.isInitialized || cameraController == null) {
if (!cameraController.value.isInitialized) {
return Container();
}
@@ -114,6 +139,16 @@ class _CameraViewState extends State<CameraView> with WidgetsBindingObserver {
predicting = true;
});
logger.i("RECIEVED IMAGE");
logger.i(cameraImage.format.group);
logger.i(cameraImage);
var converted = ImageUtils.convertCameraImage(cameraImage);
if (converted != null){
var result = classy.predict(converted);
logger.e("PREDICTED IMAGE");
logger.i(result);
}
// logger.i(cameraImage);
// logger.i(cameraImage.height);
// logger.i(cameraImage.width);
+33
View File
@@ -0,0 +1,33 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/ui/poke_view.dart';
import 'package:tensordex_mobile/utils/recognition.dart';
import '../utils/logger.dart';
/// Widget that displays classification results.
///
/// NOTE(review): the previous doc comment ("[CameraView] sends each frame
/// for inference") was copy-pasted from CameraView and described the wrong
/// widget.
class ResultsView extends StatefulWidget {
  /// Constructor
  const ResultsView({Key? key}) : super(key: key);

  // NOTE(review): this only logs — the [results] argument is unused, and a
  // StatefulWidget is immutable so nothing can be stored here. Presumably
  // results should reach _ResultsViewState (e.g. via a constructor field
  // plus setState in the State) — confirm the intended design.
  void setResults(Recognition results){
    logger.i("RESULTS IN THE RESULT VIEW");
  }

  @override
  State<ResultsView> createState() => _ResultsViewState();
}
/// State for [ResultsView]; currently renders placeholder text only.
class _ResultsViewState extends State<ResultsView> {
  // The initState override that only called super.initState() was removed:
  // a no-op override is redundant (flutter_lints: unnecessary_overrides).

  @override
  Widget build(BuildContext context) {
    // const: the literal never changes, letting Flutter skip rebuilds.
    return const Text("data");
  }
}
@@ -1,6 +1,6 @@
import 'package:flutter/material.dart';
import 'package:camera/camera.dart';
import 'package:tensordex_mobile/ui/poke_view.dart';
import 'package:tensordex_mobile/ui/results_view.dart';
import '../utils/logger.dart';
import '../utils/recognition.dart';
@@ -25,7 +25,6 @@ class TensordexHome extends StatefulWidget {
}
class _TensordexHomeState extends State<TensordexHome> {
int _counter = 0;
/// Results to draw bounding boxes
List<Recognition>? results;
@@ -38,7 +37,6 @@ class _TensordexHomeState extends State<TensordexHome> {
void _incrementCounter() {
setState(() {
_counter++;
logger.d("Counter Incremented!");
logger.w("Counter Incremented!");
logger.e("Counter Incremented!");
@@ -129,8 +127,6 @@ class _TensordexHomeState extends State<TensordexHome> {
@override
void dispose() {
// controller.dispose();
// WidgetsBinding.instance.removeObserver(this);
super.dispose();
}
@@ -158,17 +154,10 @@ class _TensordexHomeState extends State<TensordexHome> {
child: Column(
mainAxisAlignment: MainAxisAlignment.center,
children: <Widget>[
const Text(
'You have pushed the button this many times:',
),
Text(
'$_counter',
style: Theme.of(context).textTheme.headline4,
),
CameraView(
resultsCallback: resultsCallback,
statsCallback: statsCallback
),
statsCallback: statsCallback),
const ResultsView(),
],
),
),
+15 -1
View File
@@ -65,7 +65,7 @@ packages:
source: hosted
version: "1.1.0"
collection:
dependency: transitive
dependency: "direct main"
description:
name: collection
url: "https://pub.dartlang.org"
@@ -343,6 +343,20 @@ packages:
url: "https://pub.dartlang.org"
source: hosted
version: "0.9.0"
tflite_flutter_helper:
dependency: "direct main"
description:
name: tflite_flutter_helper
url: "https://pub.dartlang.org"
source: hosted
version: "0.3.1"
tuple:
dependency: transitive
description:
name: tuple
url: "https://pub.dartlang.org"
source: hosted
version: "2.0.0"
typed_data:
dependency: transitive
description:
+2
View File
@@ -39,6 +39,8 @@ dependencies:
logger: ^1.1.0
path_provider: ^2.0.11
tflite_flutter: ^0.9.0
tflite_flutter_helper: ^0.3.1
collection: ^1.16.0
dev_dependencies:
flutter_test: