From 284fa4a2f8dce3b64c8f6d12515fa6c13cb58a57 Mon Sep 17 00:00:00 2001 From: Lucas Oskorep Date: Fri, 8 Jul 2022 18:30:25 -0400 Subject: [PATCH] adding some basic linting - prepping support for multiple models being loaded in by the app. --- .gitignore | 2 + analysis_options.yaml | 4 +- android/app/src/main/AndroidManifest.xml | 2 +- lib/main.dart | 2 +- lib/tflite/classifier.dart | 41 +++++++------- lib/tflite/ml_isolate.dart | 9 ++- lib/tflite/model/configuration.dart | 16 ++++++ lib/tflite/model/constants.dart | 10 ++++ .../{data => model/outputs}/recognition.dart | 0 lib/tflite/{data => model/outputs}/stats.dart | 0 lib/widgets/poke_finder.dart | 55 +++++++++++++------ lib/widgets/results.dart | 4 +- lib/widgets/tensordex_home.dart | 8 +-- 13 files changed, 99 insertions(+), 54 deletions(-) create mode 100644 lib/tflite/model/configuration.dart create mode 100644 lib/tflite/model/constants.dart rename lib/tflite/{data => model/outputs}/recognition.dart (100%) rename lib/tflite/{data => model/outputs}/stats.dart (100%) diff --git a/.gitignore b/.gitignore index a8e938c..15f4fd9 100644 --- a/.gitignore +++ b/.gitignore @@ -45,3 +45,5 @@ app.*.map.json /android/app/debug /android/app/profile /android/app/release +/assets/mobilenetv2_gpu.tflite +/assets/mobilenetv2_gpu.tflite diff --git a/analysis_options.yaml b/analysis_options.yaml index 61b6c4d..7c78e1d 100644 --- a/analysis_options.yaml +++ b/analysis_options.yaml @@ -22,8 +22,8 @@ linter: # `// ignore_for_file: name_of_lint` syntax on the line or in the file # producing the lint. 
rules: - # avoid_print: false # Uncomment to disable the `avoid_print` rule - # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule + avoid_print: true # Uncomment to disable the `avoid_print` rule + prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule # Additional information about this file can be found at # https://dart.dev/guides/language/analysis-options diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml index e9a867a..ff28eeb 100644 --- a/android/app/src/main/AndroidManifest.xml +++ b/android/app/src/main/AndroidManifest.xml @@ -25,7 +25,7 @@ - _labels; int classifierCreationStart = -1; - Classifier({ - Interpreter? interpreter, + Classifier( + Interpreter interpreter, { List? labels, }) { - loadModel(interpreter: interpreter); + loadModel(interpreter); loadLabels(labels: labels); } /// Loads interpreter from asset - void loadModel({Interpreter? interpreter}) async { + void loadModel(Interpreter interpreter) async { try { - _interpreter = interpreter ?? - await Interpreter.fromAsset( - modelFileName, - options: InterpreterOptions()..threads = 8, - ); + _interpreter = interpreter; var outputTensor = _interpreter.getOutputTensor(0); var outputShape = outputTensor.shape; _outputType = outputTensor.type; var inputTensor = _interpreter.getInputTensor(0); - // var intputShape = inputTensor.shape; _inputType = inputTensor.type; _inputImage = TensorImage(_inputType); _outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType); _outputProcessor = TensorProcessorBuilder().add(NormalizeOp(0, 1)).build(); } catch (e) { - logger.e("Error while creating interpreter: ", e); + logger.e('Error while creating interpreter: ', e); } } /// Loads labels from assets void loadLabels({List? labels}) async { try { - _labels = labels ?? await FileUtil.loadLabels("assets/labels.txt"); + _labels = labels ?? 
await FileUtil.loadLabels('assets/labels.txt'); } catch (e) { - logger.e("Error while loading labels: $e"); + logger.e('Error while loading labels: $e'); } } /// Pre-process the image TensorImage? getProcessedImage(TensorImage? inputImage) { - // padSize = max(inputImage.height, inputImage.width); if (inputImage != null) { + int cropSize = min(inputImage.height, inputImage.width); imageProcessor ??= ImageProcessorBuilder() - .add(ResizeWithCropOrPadOp(224, 224)) + .add(ResizeWithCropOrPadOp(cropSize, cropSize)) .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR)) .add(NormalizeOp(0, 1)) - // .add(NormalizeOp(127.5, 127.5)) + // .add(NormalizeOp(127.5, 127.5)) // photo vs quant normalization .build(); return imageProcessor?.process(inputImage); } @@ -102,8 +99,8 @@ class Classifier { .toList(); var endTime = DateTime.now().millisecondsSinceEpoch; return { - "recognitions": predictions, - "stats": Stats( + 'recognitions': predictions, + 'stats': Stats( totalTime: endTime - preProcStart, preProcessingTime: inferenceStart - preProcStart, inferenceTime: postProcStart - inferenceStart, diff --git a/lib/tflite/ml_isolate.dart b/lib/tflite/ml_isolate.dart index 4e9314f..791759d 100644 --- a/lib/tflite/ml_isolate.dart +++ b/lib/tflite/ml_isolate.dart @@ -11,7 +11,7 @@ class IsolateBase { } class MLIsolate extends IsolateBase { - static const String debugIsolate = "MLIsolate"; + static const String debugIsolate = 'MLIsolate'; late SendPort _sendPort; SendPort get sendPort => _sendPort; @@ -34,19 +34,18 @@ class MLIsolate extends IsolateBase { var converted = ImageUtils.convertCameraImage(cameraImage); if (converted != null) { Classifier classifier = Classifier( - interpreter: - Interpreter.fromAddress(mlIsolateData.interpreterAddress), + Interpreter.fromAddress(mlIsolateData.interpreterAddress), labels: mlIsolateData.labels); var result = classifier.predict(converted); mlIsolateData.responsePort?.send(result); } else { - mlIsolateData.responsePort?.send({"response": 
"not working yet"}); + mlIsolateData.responsePort?.send({'response': 'not working yet'}); } } } } -/// Bundles data to pass between Isolate +/// Bundles outputs to pass between Isolate class MLIsolateData { CameraImage cameraImage; int interpreterAddress; diff --git a/lib/tflite/model/configuration.dart b/lib/tflite/model/configuration.dart new file mode 100644 index 0000000..365ba8c --- /dev/null +++ b/lib/tflite/model/configuration.dart @@ -0,0 +1,16 @@ +import 'package:tflite_flutter/tflite_flutter.dart'; +import 'constants.dart'; + +class ModelConfiguration{ + String name; + late List interpreters; + + ModelConfiguration(this.name){ + interpreters = name.contains('gpu') ? ModelConstants.gpuInterpreterList : ModelConstants.cpuInterpreterList; + } + + @override + String toString() { + return 'ModelConfiguration(name: $name, interpreters: $interpreters)'; + } +} \ No newline at end of file diff --git a/lib/tflite/model/constants.dart b/lib/tflite/model/constants.dart new file mode 100644 index 0000000..072e8e1 --- /dev/null +++ b/lib/tflite/model/constants.dart @@ -0,0 +1,10 @@ +import 'package:tflite_flutter/tflite_flutter.dart'; + + +class ModelConstants { + static final InterpreterOptions _npuConfig = InterpreterOptions()..threads = 8..useNnApiForAndroid = true..useMetalDelegateForIOS = true; + static final InterpreterOptions _cpuConfig = InterpreterOptions()..threads = 8; + static final List gpuInterpreterList = [_npuConfig, _cpuConfig]; + static final List cpuInterpreterList = [_cpuConfig]; +} + diff --git a/lib/tflite/data/recognition.dart b/lib/tflite/model/outputs/recognition.dart similarity index 100% rename from lib/tflite/data/recognition.dart rename to lib/tflite/model/outputs/recognition.dart diff --git a/lib/tflite/data/stats.dart b/lib/tflite/model/outputs/stats.dart similarity index 100% rename from lib/tflite/data/stats.dart rename to lib/tflite/model/outputs/stats.dart diff --git a/lib/widgets/poke_finder.dart b/lib/widgets/poke_finder.dart index 
e453131..dbbcca3 100644 --- a/lib/widgets/poke_finder.dart +++ b/lib/widgets/poke_finder.dart @@ -1,16 +1,19 @@ +import 'dart:convert'; import 'dart:isolate'; import 'package:camera/camera.dart'; import 'package:flutter/material.dart'; +import 'package:flutter/services.dart'; import 'package:tensordex_mobile/tflite/ml_isolate.dart'; +import 'package:tensordex_mobile/tflite/model/configuration.dart'; +import 'package:tensordex_mobile/tflite/model/outputs/stats.dart'; import 'package:tflite_flutter/tflite_flutter.dart'; import '../tflite/classifier.dart'; +import '../tflite/model/outputs/recognition.dart'; import '../utils/logger.dart'; -import '../tflite/data/recognition.dart'; -import '../tflite/data/stats.dart'; -/// [PokeFinder] sends each frame for inference + class PokeFinder extends StatefulWidget { /// Callback to pass results after inference to [HomeView] final Function(List recognitions) resultsCallback; @@ -28,17 +31,20 @@ class PokeFinder extends StatefulWidget { } class _PokeFinderState extends State with WidgetsBindingObserver { - late List cameras; - late CameraController cameraController; - late MLIsolate _mlIsolate; - /// true when inference is ongoing bool predicting = false; bool _cameraInitialized = false; bool _classifierInitialized = false; + //cameras + late List cameras; + late CameraController cameraController; + + //ml variables late Interpreter interpreter; late Classifier classifier; + late MLIsolate _mlIsolate; + late List modelConfigurations; @override void initState() { @@ -55,19 +61,34 @@ class _PokeFinderState extends State with WidgetsBindingObserver { predicting = false; } + Future> getModelFiles() async { + final manifestContent = await rootBundle.loadString('AssetManifest.json'); + final Map manifestMap = json.decode(manifestContent); + return manifestMap.keys + .where((String key) => key.contains('.tflite')) + .map((String key) => key.substring(7)) + .toList(); + } + void initializeModel() async { - var interpreterOptions = 
InterpreterOptions()..threads = 8; - interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite', - options: interpreterOptions); - classifier = Classifier(interpreter: interpreter); + var modelFiles = await getModelFiles(); + var modelConfigurations = + modelFiles.map((e) => ModelConfiguration(e)).toList(); + var currentConfig = modelConfigurations[0]; + logger.i(modelFiles); + interpreter = await createInterpreter(currentConfig); + classifier = Classifier(interpreter); _classifierInitialized = true; } + Future createInterpreter(ModelConfiguration config) async { + return await Interpreter.fromAsset(config.name, + options: config.interpreters[0]); + } + /// Initializes the camera by setting [cameraController] void initializeCamera() async { cameras = await availableCameras(); - - // cameras[0] for rear-camera cameraController = CameraController(cameras[0], ResolutionPreset.low, enableAudio: false); @@ -94,11 +115,11 @@ class _PokeFinderState extends State with WidgetsBindingObserver { var results = await inference(MLIsolateData( cameraImage, classifier.interpreter.address, classifier.labels)); - if (results.containsKey("recognitions")) { - widget.resultsCallback(results["recognitions"]); + if (results.containsKey('recognitions')) { + widget.resultsCallback(results['recognitions']); } - if (results.containsKey("stats")) { - widget.statsCallback(results["stats"]); + if (results.containsKey('stats')) { + widget.statsCallback(results['stats']); } logger.i(results); diff --git a/lib/widgets/results.dart b/lib/widgets/results.dart index c812b16..5e2aae6 100644 --- a/lib/widgets/results.dart +++ b/lib/widgets/results.dart @@ -1,7 +1,7 @@ import 'package:flutter/material.dart'; import 'package:tensordex_mobile/widgets/poke_finder.dart'; -import 'package:tensordex_mobile/tflite/data/recognition.dart'; -import 'package:tensordex_mobile/tflite/data/stats.dart'; +import '../tflite/model/outputs/recognition.dart'; +import '../tflite/model/outputs/stats.dart'; /// 
[PokeFinder] sends each frame for inference diff --git a/lib/widgets/tensordex_home.dart b/lib/widgets/tensordex_home.dart index fc9b9e5..e65cc99 100644 --- a/lib/widgets/tensordex_home.dart +++ b/lib/widgets/tensordex_home.dart @@ -1,10 +1,10 @@ import 'package:flutter/material.dart'; +import 'package:tensordex_mobile/tflite/model/outputs/recognition.dart'; +import 'package:tensordex_mobile/tflite/model/outputs/stats.dart'; import 'package:tensordex_mobile/widgets/poke_finder.dart'; import 'package:tensordex_mobile/widgets/results.dart'; import '../utils/logger.dart'; -import '../tflite/data/recognition.dart'; -import '../tflite/data/stats.dart'; class TensordexHome extends StatefulWidget { const TensordexHome({Key? key, required this.title}) : super(key: key); @@ -22,7 +22,7 @@ class TensordexHome extends StatefulWidget { class _TensordexHomeState extends State { /// Results from the image classifier - List results = [Recognition(1, "NOTHING DETECTED", .5)]; + List results = [Recognition(1, 'NOTHING DETECTED', .5)]; Stats stats = Stats(); /// Scaffold Key @@ -30,7 +30,7 @@ class _TensordexHomeState extends State { void _incrementCounter() { setState(() { - logger.d("Counter Incremented!"); + logger.d('Counter Incremented!'); }); }