diff --git a/install_tflite.bat b/install_tflite.bat new file mode 100644 index 0000000..8100952 --- /dev/null +++ b/install_tflite.bat @@ -0,0 +1,38 @@ +@echo off +setlocal enableextensions + +cd %~dp0 + +set TF_VERSION=2.5 +set URL=https://github.com/am15h/tflite_flutter_plugin/releases/download/ +set TAG=tf_%TF_VERSION% + +set ANDROID_DIR=android\app\src\main\jniLibs\ +set ANDROID_LIB=libtensorflowlite_c.so + +set ARM_DELEGATE=libtensorflowlite_c_arm_delegate.so +set ARM_64_DELEGATE=libtensorflowlite_c_arm64_delegate.so +set ARM=libtensorflowlite_c_arm.so +set ARM_64=libtensorflowlite_c_arm64.so +set X86=libtensorflowlite_c_x86_delegate.so +set X86_64=libtensorflowlite_c_x86_64_delegate.so + +SET /A d = 0 + +:GETOPT +if /I "%1"=="-d" SET /A d = 1 + +SETLOCAL +if %d%==1 CALL :Download %ARM_DELEGATE% armeabi-v7a +if %d%==1 CALL :Download %ARM_64_DELEGATE% arm64-v8a +if %d%==0 CALL :Download %ARM% armeabi-v7a +if %d%==0 CALL :Download %ARM_64% arm64-v8a +CALL :Download %X86% x86 +CALL :Download %X86_64% x86_64 +EXIT /B %ERRORLEVEL% + +:Download +curl -L -o %~1 %URL%%TAG%/%~1 +mkdir %ANDROID_DIR%%~2\ +move /-Y %~1 %ANDROID_DIR%%~2\%ANDROID_LIB% +EXIT /B 0 diff --git a/install_tflite.sh b/install_tflite.sh new file mode 100644 index 0000000..4185ac0 --- /dev/null +++ b/install_tflite.sh @@ -0,0 +1,52 @@ +#!/usr/bin/env bash + +cd "$(dirname "$(readlink -f "$0")")" + +# Available versions +# 2.5, 2.4.1 + +TF_VERSION=2.5 +URL="https://github.com/am15h/tflite_flutter_plugin/releases/download/" +TAG="tf_$TF_VERSION" + +ANDROID_DIR="android/app/src/main/jniLibs/" +ANDROID_LIB="libtensorflowlite_c.so" + +ARM_DELEGATE="libtensorflowlite_c_arm_delegate.so" +ARM_64_DELEGATE="libtensorflowlite_c_arm64_delegate.so" +ARM="libtensorflowlite_c_arm.so" +ARM_64="libtensorflowlite_c_arm64.so" +X86="libtensorflowlite_c_x86_delegate.so" +X86_64="libtensorflowlite_c_x86_64_delegate.so" + +delegate=0 + +while getopts "d" OPTION +do + case $OPTION in + d) delegate=1;; + esac +done + + +download () { + wget "${URL}${TAG}/$1" -O "$1" + mkdir -p "${ANDROID_DIR}$2/" + mv $1 "${ANDROID_DIR}$2/${ANDROID_LIB}" +} + +if [ ${delegate} -eq 1 ] +then + +download ${ARM_DELEGATE} "armeabi-v7a" +download ${ARM_64_DELEGATE} "arm64-v8a" + +else + +download ${ARM} "armeabi-v7a" +download ${ARM_64} "arm64-v8a" + +fi + +download ${X86} "x86" +download ${X86_64} "x86_64" diff --git a/lib/classifier.dart b/lib/classifier.dart deleted file mode 100644 index e69de29..0000000 diff --git a/lib/main.dart b/lib/main.dart index dc3ea3d..47b1a09 100644 --- a/lib/main.dart +++ b/lib/main.dart @@ -1,5 +1,5 @@ import 'package:flutter/material.dart'; -import 'package:tensordex_mobile/ui/home.dart'; +import 'package:tensordex_mobile/ui/tensordex_home.dart'; import 'package:tensordex_mobile/utils/logger.dart'; Future main() async { diff --git a/lib/tflite/classifier.dart b/lib/tflite/classifier.dart new file mode 100644 index 0000000..fc65b16 --- /dev/null +++ b/lib/tflite/classifier.dart @@ -0,0 +1,160 @@ +import 'dart:math'; +import 'dart:ui'; + +import 'package:collection/collection.dart'; +import 'package:image/image.dart' as image_lib; +import 'package:tflite_flutter/tflite_flutter.dart'; +import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; + +import '../utils/logger.dart'; +import '../utils/recognition.dart'; +import '../utils/stats.dart'; + +/// Classifier +class Classifier { + static const String MODEL_FILE_NAME = "detect.tflite"; + static const String LABEL_FILE_NAME = "labelmap.txt"; + + /// Input size of image (height = width = 300) + static const int INPUT_SIZE = 224; + + /// Result score threshold + static const double THRESHOLD = 0.5; + + /// [ImageProcessor] used to pre-process the image + ImageProcessor? imageProcessor; + + /// Padding the image to transform into square + // int padSize = 0; + /// Instance of Interpreter + late Interpreter _interpreter; + + late TensorBuffer _outputBuffer; + late var _probabilityProcessor; + + /// Labels file loaded as list + late List _labels; + + /// Number of results to show + static const int NUM_RESULTS = 10; + + Classifier({ + Interpreter? interpreter, + List? labels, + }) { + loadModel(interpreter: interpreter); + loadLabels(labels: labels); + } + + /// Loads interpreter from asset + void loadModel({Interpreter? interpreter}) async { + try { + _interpreter = interpreter ?? + await Interpreter.fromAsset( + MODEL_FILE_NAME, + options: InterpreterOptions()..threads = 4, + ); + var outputTensor = _interpreter.getOutputTensor(0); + var outputShape = outputTensor.shape; + var outputType = outputTensor.type; + + var inputTensor = _interpreter.getInputTensor(0); + var intputShape = inputTensor.shape; + var intputType = inputTensor.type; + + _outputBuffer = TensorBuffer.createFixedSize(outputShape, outputType); + _probabilityProcessor = + TensorProcessorBuilder().add(NormalizeOp(0, 1)).build(); + } catch (e) { + logger.e("Error while creating interpreter: ", e); + } + } + + /// Loads labels from assets + void loadLabels({List? labels}) async { + try { + _labels = labels ?? await FileUtil.loadLabels("assets/labels.txt"); + } catch (e) { + logger.e("Error while loading labels: $e"); + } + } + + /// Pre-process the image + TensorImage? getProcessedImage(TensorImage inputImage) { + // padSize = max(inputImage.height, inputImage.width); + imageProcessor ??= ImageProcessorBuilder() + // .add(ResizeWithCropOrPadOp(padSize, padSize)) + .add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR)) + .add(NormalizeOp(127.5, 127.5)) + .build(); + return imageProcessor?.process(inputImage); + } + + /// Runs object detection on the input image + Map? predict(image_lib.Image image) { + logger.i(labels); + var predictStartTime = DateTime.now().millisecondsSinceEpoch; + if (_interpreter == null) { + logger.e("Interpreter not initialized"); + return null; + } + var preProcessStart = DateTime.now().millisecondsSinceEpoch; + // Create TensorImage from image + // Pre-process TensorImage + var procImage = getProcessedImage(TensorImage.fromImage(image)); + + var preProcessElapsedTime = + DateTime.now().millisecondsSinceEpoch - preProcessStart; + if (procImage != null) { + var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch; + // run inference + var inferenceTimeElapsed = + DateTime.now().millisecondsSinceEpoch - inferenceTimeStart; + + logger.i("Sending image to ML"); + + logger.i(procImage.buffer.asFloat32List()); + logger.i(procImage.width); + logger.i(procImage.height); + logger.i(procImage.tensorBuffer.shape); + logger.i(procImage.tensorBuffer.isDynamic); + _interpreter.run(procImage.buffer, _outputBuffer.getBuffer()); + + Map labeledProb = TensorLabel.fromList( + labels, _probabilityProcessor.process(_outputBuffer)) + .getMapWithFloatValue(); + final pred = getTopProbability(labeledProb); + Recognition rec = Recognition(1, pred.key, pred.value); + var predictElapsedTime = DateTime.now().millisecondsSinceEpoch - predictStartTime; + return { + "recognitions": rec, + "stats": Stats(predictElapsedTime, predictElapsedTime, predictElapsedTime, predictElapsedTime), + }; + } else { + return null; + } + } + + /// Gets the interpreter instance + Interpreter get interpreter => _interpreter; + + /// Gets the loaded labels + List get labels => _labels; +} + +MapEntry getTopProbability(Map labeledProb) { + var pq = PriorityQueue>(compare); + pq.addAll(labeledProb.entries); + + return pq.first; +} + +int compare(MapEntry e1, MapEntry e2) { + if (e1.value > e2.value) { + return -1; + } else if (e1.value == e2.value) { + return 0; + } else { + return 1; + } +} diff --git a/lib/ui/poke_view.dart b/lib/ui/poke_view.dart index b32ff75..72d6845 100644 --- a/lib/ui/poke_view.dart +++ b/lib/ui/poke_view.dart @@ -2,6 +2,9 @@ import 'dart:isolate'; import 'package:camera/camera.dart'; import 'package:flutter/material.dart'; +import 'package:tensordex_mobile/tflite/classifier.dart'; +import 'package:tflite_flutter/tflite_flutter.dart'; +import 'package:tensordex_mobile/utils/image_utils.dart'; import '../utils/logger.dart'; import '../utils/recognition.dart'; @@ -30,10 +33,13 @@ class _CameraViewState extends State with WidgetsBindingObserver { /// Controller late CameraController cameraController; + Interpreter? interp; /// true when inference is ongoing bool predicting = false; + late Classifier classy; + // /// Instance of [Classifier] // Classifier classifier; // @@ -56,9 +62,28 @@ class _CameraViewState extends State with WidgetsBindingObserver { // Camera initialization initializeCamera(); + // final gpuDelegateV2 = GpuDelegateV2( + // options: GpuDelegateOptionsV2( + // isPrecisionLossAllowed: false, + // inferencePreference: TfLiteGpuInferenceUsage.fastSingleAnswer, + // inferencePriority1: TfLiteGpuInferencePriority.minLatency, + // inferencePriority2: TfLiteGpuInferencePriority.auto, + // inferencePriority3: TfLiteGpuInferencePriority.auto, + // )); + + + logger.e("CREATING THE INTERPRETOR"); + var interpreterOptions = InterpreterOptions();//..addDelegate(gpuDelegateV2); + interp = await Interpreter.fromAsset('efficientnet_v2s.tflite', + options: interpreterOptions); + logger.e("CREATING THE INTERPRETOR"); + + classy = Classifier(interpreter: interp); + logger.i(interp?.getOutputTensors()); // Create an instance of classifier to load model and labels // classifier = Classifier(); + // Initially predicting = false predicting = false; } @@ -94,7 +119,7 @@ class _CameraViewState extends State with WidgetsBindingObserver { @override Widget build(BuildContext context) { // Return empty container while the camera is not initialized - if (!cameraController.value.isInitialized || cameraController == null) { + if (!cameraController.value.isInitialized) { return Container(); } @@ -114,6 +139,16 @@ class _CameraViewState extends State with WidgetsBindingObserver { predicting = true; }); logger.i("RECIEVED IMAGE"); + logger.i(cameraImage.format.group); + logger.i(cameraImage); + var converted = ImageUtils.convertCameraImage(cameraImage); + if (converted != null){ + + var result = classy.predict(converted); + + logger.e("PREDICTED IMAGE"); + logger.i(result); + } // logger.i(cameraImage); // logger.i(cameraImage.height); // logger.i(cameraImage.width); diff --git a/lib/ui/results_view.dart b/lib/ui/results_view.dart index e69de29..4c0af40 100644 --- a/lib/ui/results_view.dart +++ b/lib/ui/results_view.dart @@ -0,0 +1,33 @@ +import 'package:flutter/material.dart'; +import 'package:tensordex_mobile/ui/poke_view.dart'; +import 'package:tensordex_mobile/utils/recognition.dart'; + +import '../utils/logger.dart'; + +/// [CameraView] sends each frame for inference +class ResultsView extends StatefulWidget { + + /// Constructor + const ResultsView({Key? key}) : super(key: key); + + + void setResults(Recognition results){ + logger.i("RESULTS IN THE RESULT VIEW"); + } + + @override + State createState() => _ResultsViewState(); +} + +class _ResultsViewState extends State { + + @override + void initState() { + super.initState(); + } + + @override + Widget build(BuildContext context) { + return Text("data"); + } +} \ No newline at end of file diff --git a/lib/ui/home.dart b/lib/ui/tensordex_home.dart similarity index 91% rename from lib/ui/home.dart rename to lib/ui/tensordex_home.dart index e2ee1c8..27a0096 100644 --- a/lib/ui/home.dart +++ b/lib/ui/tensordex_home.dart @@ -1,6 +1,6 @@ import 'package:flutter/material.dart'; -import 'package:camera/camera.dart'; import 'package:tensordex_mobile/ui/poke_view.dart'; +import 'package:tensordex_mobile/ui/results_view.dart'; import '../utils/logger.dart'; import '../utils/recognition.dart'; @@ -25,7 +25,6 @@ class TensordexHome extends StatefulWidget { } class _TensordexHomeState extends State { - int _counter = 0; /// Results to draw bounding boxes List? results; @@ -38,7 +37,6 @@ class _TensordexHomeState extends State { void _incrementCounter() { setState(() { - _counter++; logger.d("Counter Incremented!"); logger.w("Counter Incremented!"); logger.e("Counter Incremented!"); @@ -129,8 +127,6 @@ class _TensordexHomeState extends State { @override void dispose() { - // controller.dispose(); - // WidgetsBinding.instance.removeObserver(this); super.dispose(); } @@ -158,17 +154,10 @@ class _TensordexHomeState extends State { child: Column( mainAxisAlignment: MainAxisAlignment.center, children: [ - const Text( - 'You have pushed the button this many times:', - ), - Text( - '$_counter', - style: Theme.of(context).textTheme.headline4, - ), CameraView( resultsCallback: resultsCallback, - statsCallback: statsCallback - ), + statsCallback: statsCallback), + const ResultsView(), ], ), ), diff --git a/pubspec.lock b/pubspec.lock index 12fa155..ef2a499 100644 --- a/pubspec.lock +++ b/pubspec.lock @@ -65,7 +65,7 @@ packages: source: hosted version: "1.1.0" collection: - dependency: transitive + dependency: "direct main" description: name: collection url: "https://pub.dartlang.org" @@ -343,6 +343,20 @@ packages: url: "https://pub.dartlang.org" source: hosted version: "0.9.0" + tflite_flutter_helper: + dependency: "direct main" + description: + name: tflite_flutter_helper + url: "https://pub.dartlang.org" + source: hosted + version: "0.3.1" + tuple: + dependency: transitive + description: + name: tuple + url: "https://pub.dartlang.org" + source: hosted + version: "2.0.0" typed_data: dependency: transitive description: diff --git a/pubspec.yaml b/pubspec.yaml index 6cdd924..b14d8ed 100644 --- a/pubspec.yaml +++ b/pubspec.yaml @@ -39,6 +39,8 @@ dependencies: logger: ^1.1.0 path_provider: ^2.0.11 tflite_flutter: ^0.9.0 + tflite_flutter_helper: ^0.3.1 + collection: ^1.16.0 dev_dependencies: flutter_test: