diff --git a/.gitignore b/.gitignore
index a8e938c..15f4fd9 100644
--- a/.gitignore
+++ b/.gitignore
@@ -45,3 +45,5 @@ app.*.map.json
/android/app/debug
/android/app/profile
/android/app/release
+/assets/mobilenetv2_gpu.tflite
+/assets/mobilenetv2_gpu.tflite
diff --git a/analysis_options.yaml b/analysis_options.yaml
index 61b6c4d..7c78e1d 100644
--- a/analysis_options.yaml
+++ b/analysis_options.yaml
@@ -22,8 +22,8 @@ linter:
# `// ignore_for_file: name_of_lint` syntax on the line or in the file
# producing the lint.
rules:
- # avoid_print: false # Uncomment to disable the `avoid_print` rule
- # prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule
+ avoid_print: true # Uncomment to disable the `avoid_print` rule
+ prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule
# Additional information about this file can be found at
# https://dart.dev/guides/language/analysis-options
diff --git a/android/app/src/main/AndroidManifest.xml b/android/app/src/main/AndroidManifest.xml
index e9a867a..ff28eeb 100644
--- a/android/app/src/main/AndroidManifest.xml
+++ b/android/app/src/main/AndroidManifest.xml
@@ -25,7 +25,7 @@
-
_labels;
int classifierCreationStart = -1;
- Classifier({
- Interpreter? interpreter,
+ Classifier(
+ Interpreter interpreter, {
List? labels,
}) {
- loadModel(interpreter: interpreter);
+ loadModel(interpreter);
loadLabels(labels: labels);
}
/// Loads interpreter from asset
- void loadModel({Interpreter? interpreter}) async {
+ void loadModel(Interpreter interpreter) async {
try {
- _interpreter = interpreter ??
- await Interpreter.fromAsset(
- modelFileName,
- options: InterpreterOptions()..threads = 8,
- );
+ _interpreter = interpreter;
var outputTensor = _interpreter.getOutputTensor(0);
var outputShape = outputTensor.shape;
_outputType = outputTensor.type;
var inputTensor = _interpreter.getInputTensor(0);
- // var intputShape = inputTensor.shape;
_inputType = inputTensor.type;
_inputImage = TensorImage(_inputType);
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
_outputProcessor =
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
} catch (e) {
- logger.e("Error while creating interpreter: ", e);
+ logger.e('Error while creating interpreter: ', e);
}
}
/// Loads labels from assets
void loadLabels({List? labels}) async {
try {
- _labels = labels ?? await FileUtil.loadLabels("assets/labels.txt");
+ _labels = labels ?? await FileUtil.loadLabels('assets/labels.txt');
} catch (e) {
- logger.e("Error while loading labels: $e");
+ logger.e('Error while loading labels: $e');
}
}
/// Pre-process the image
TensorImage? getProcessedImage(TensorImage? inputImage) {
- // padSize = max(inputImage.height, inputImage.width);
+ int cropSize = min(_inputImage.height, _inputImage.width);
if (inputImage != null) {
imageProcessor ??= ImageProcessorBuilder()
- .add(ResizeWithCropOrPadOp(224, 224))
+ .add(ResizeWithCropOrPadOp(cropSize, cropSize))
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
.add(NormalizeOp(0, 1))
- // .add(NormalizeOp(127.5, 127.5))
+ // .add(NormalizeOp(127.5, 127.5)) // photo vs quant normalization
.build();
return imageProcessor?.process(inputImage);
}
@@ -102,8 +99,8 @@ class Classifier {
.toList();
var endTime = DateTime.now().millisecondsSinceEpoch;
return {
- "recognitions": predictions,
- "stats": Stats(
+ 'recognitions': predictions,
+ 'stats': Stats(
totalTime: endTime - preProcStart,
preProcessingTime: inferenceStart - preProcStart,
inferenceTime: postProcStart - inferenceStart,
diff --git a/lib/tflite/ml_isolate.dart b/lib/tflite/ml_isolate.dart
index 4e9314f..791759d 100644
--- a/lib/tflite/ml_isolate.dart
+++ b/lib/tflite/ml_isolate.dart
@@ -11,7 +11,7 @@ class IsolateBase {
}
class MLIsolate extends IsolateBase {
- static const String debugIsolate = "MLIsolate";
+ static const String debugIsolate = 'MLIsolate';
late SendPort _sendPort;
SendPort get sendPort => _sendPort;
@@ -34,19 +34,18 @@ class MLIsolate extends IsolateBase {
var converted = ImageUtils.convertCameraImage(cameraImage);
if (converted != null) {
Classifier classifier = Classifier(
- interpreter:
- Interpreter.fromAddress(mlIsolateData.interpreterAddress),
+ Interpreter.fromAddress(mlIsolateData.interpreterAddress),
labels: mlIsolateData.labels);
var result = classifier.predict(converted);
mlIsolateData.responsePort?.send(result);
} else {
- mlIsolateData.responsePort?.send({"response": "not working yet"});
+ mlIsolateData.responsePort?.send({'response': 'not working yet'});
}
}
}
}
-/// Bundles data to pass between Isolate
+/// Bundles outputs to pass between Isolate
class MLIsolateData {
CameraImage cameraImage;
int interpreterAddress;
diff --git a/lib/tflite/model/configuration.dart b/lib/tflite/model/configuration.dart
new file mode 100644
index 0000000..365ba8c
--- /dev/null
+++ b/lib/tflite/model/configuration.dart
@@ -0,0 +1,16 @@
+import 'package:tflite_flutter/tflite_flutter.dart';
+import 'constants.dart';
+
+class ModelConfiguration{
+ String name;
+ late List interpreters;
+
+ ModelConfiguration(this.name){
+ interpreters = name.contains('gpu') ? ModelConstants.gpuInterpreterList : ModelConstants.cpuInterpreterList;
+ }
+
+ @override
+ String toString() {
+ return 'ModelConfiguration(name: $name, interpreters: $interpreters)';
+ }
+}
\ No newline at end of file
diff --git a/lib/tflite/model/constants.dart b/lib/tflite/model/constants.dart
new file mode 100644
index 0000000..072e8e1
--- /dev/null
+++ b/lib/tflite/model/constants.dart
@@ -0,0 +1,10 @@
+import 'package:tflite_flutter/tflite_flutter.dart';
+
+
+class ModelConstants {
+ static final InterpreterOptions _npuConfig = InterpreterOptions()..threads = 8..useNnApiForAndroid = true..useMetalDelegateForIOS = true;
+ static final InterpreterOptions _cpuConfig = InterpreterOptions()..threads = 8;
+ static final List gpuInterpreterList = [_npuConfig, _cpuConfig];
+ static final List cpuInterpreterList = [_cpuConfig];
+}
+
diff --git a/lib/tflite/data/recognition.dart b/lib/tflite/model/outputs/recognition.dart
similarity index 100%
rename from lib/tflite/data/recognition.dart
rename to lib/tflite/model/outputs/recognition.dart
diff --git a/lib/tflite/data/stats.dart b/lib/tflite/model/outputs/stats.dart
similarity index 100%
rename from lib/tflite/data/stats.dart
rename to lib/tflite/model/outputs/stats.dart
diff --git a/lib/widgets/poke_finder.dart b/lib/widgets/poke_finder.dart
index e453131..dbbcca3 100644
--- a/lib/widgets/poke_finder.dart
+++ b/lib/widgets/poke_finder.dart
@@ -1,16 +1,19 @@
+import 'dart:convert';
import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
+import 'package:flutter/services.dart';
import 'package:tensordex_mobile/tflite/ml_isolate.dart';
+import 'package:tensordex_mobile/tflite/model/configuration.dart';
+import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import '../tflite/classifier.dart';
+import '../tflite/model/outputs/recognition.dart';
import '../utils/logger.dart';
-import '../tflite/data/recognition.dart';
-import '../tflite/data/stats.dart';
-/// [PokeFinder] sends each frame for inference
+
class PokeFinder extends StatefulWidget {
/// Callback to pass results after inference to [HomeView]
final Function(List recognitions) resultsCallback;
@@ -28,17 +31,20 @@ class PokeFinder extends StatefulWidget {
}
class _PokeFinderState extends State with WidgetsBindingObserver {
- late List cameras;
- late CameraController cameraController;
- late MLIsolate _mlIsolate;
-
/// true when inference is ongoing
bool predicting = false;
bool _cameraInitialized = false;
bool _classifierInitialized = false;
+ //cameras
+ late List cameras;
+ late CameraController cameraController;
+
+ //ml variables
late Interpreter interpreter;
late Classifier classifier;
+ late MLIsolate _mlIsolate;
+ late List modelConfigurations;
@override
void initState() {
@@ -55,19 +61,34 @@ class _PokeFinderState extends State with WidgetsBindingObserver {
predicting = false;
}
+ Future> getModelFiles() async {
+ final manifestContent = await rootBundle.loadString('AssetManifest.jsn');
+ final Map manifestMap = json.decode(manifestContent);
+ return manifestMap.keys
+ .where((String key) => key.contains('.tflite'))
+ .map((String key) => key.substring(7))
+ .toList();
+ }
+
void initializeModel() async {
- var interpreterOptions = InterpreterOptions()..threads = 8;
- interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite',
- options: interpreterOptions);
- classifier = Classifier(interpreter: interpreter);
+ var modelFiles = await getModelFiles();
+ var modelConfigurations =
+ modelFiles.map((e) => ModelConfiguration(e)).toList();
+ var currentConfig = modelConfigurations[0];
+ logger.i(modelFiles);
+ interpreter = await createInterpreter(currentConfig);
+ classifier = Classifier(interpreter);
_classifierInitialized = true;
}
+ Future createInterpreter(ModelConfiguration config) async {
+ return await Interpreter.fromAsset(config.name,
+ options: config.interpreters[0]);
+ }
+
/// Initializes the camera by setting [cameraController]
void initializeCamera() async {
cameras = await availableCameras();
-
- // cameras[0] for rear-camera
cameraController =
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
@@ -94,11 +115,11 @@ class _PokeFinderState extends State with WidgetsBindingObserver {
var results = await inference(MLIsolateData(
cameraImage, classifier.interpreter.address, classifier.labels));
- if (results.containsKey("recognitions")) {
- widget.resultsCallback(results["recognitions"]);
+ if (results.containsKey('recognitions')) {
+ widget.resultsCallback(results['recognitions']);
}
- if (results.containsKey("stats")) {
- widget.statsCallback(results["stats"]);
+ if (results.containsKey('stats')) {
+ widget.statsCallback(results['stats']);
}
logger.i(results);
diff --git a/lib/widgets/results.dart b/lib/widgets/results.dart
index c812b16..5e2aae6 100644
--- a/lib/widgets/results.dart
+++ b/lib/widgets/results.dart
@@ -1,7 +1,7 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart';
-import 'package:tensordex_mobile/tflite/data/recognition.dart';
-import 'package:tensordex_mobile/tflite/data/stats.dart';
+import '../tflite/model/outputs/recognition.dart';
+import '../tflite/model/outputs/stats.dart';
/// [PokeFinder] sends each frame for inference
diff --git a/lib/widgets/tensordex_home.dart b/lib/widgets/tensordex_home.dart
index fc9b9e5..e65cc99 100644
--- a/lib/widgets/tensordex_home.dart
+++ b/lib/widgets/tensordex_home.dart
@@ -1,10 +1,10 @@
import 'package:flutter/material.dart';
+import 'package:tensordex_mobile/tflite/model/outputs/recognition.dart';
+import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart';
import 'package:tensordex_mobile/widgets/results.dart';
import '../utils/logger.dart';
-import '../tflite/data/recognition.dart';
-import '../tflite/data/stats.dart';
class TensordexHome extends StatefulWidget {
const TensordexHome({Key? key, required this.title}) : super(key: key);
@@ -22,7 +22,7 @@ class TensordexHome extends StatefulWidget {
class _TensordexHomeState extends State {
/// Results from the image classifier
- List results = [Recognition(1, "NOTHING DETECTED", .5)];
+ List results = [Recognition(1, 'NOTHING DETECTED', .5)];
Stats stats = Stats();
/// Scaffold Key
@@ -30,7 +30,7 @@ class _TensordexHomeState extends State {
void _incrementCounter() {
setState(() {
- logger.d("Counter Incremented!");
+ logger.d('Counter Incremented!');
});
}