Adding some basic linting; prepping support for multiple models being loaded by the app.

This commit is contained in:
Lucas Oskorep
2022-07-08 18:30:25 -04:00
parent b8119e6520
commit 284fa4a2f8
13 changed files with 99 additions and 54 deletions
+2
View File
@@ -45,3 +45,5 @@ app.*.map.json
/android/app/debug /android/app/debug
/android/app/profile /android/app/profile
/android/app/release /android/app/release
/assets/mobilenetv2_gpu.tflite
/assets/mobilenetv2_gpu.tflite
+2 -2
View File
@@ -22,8 +22,8 @@ linter:
# `// ignore_for_file: name_of_lint` syntax on the line or in the file # `// ignore_for_file: name_of_lint` syntax on the line or in the file
# producing the lint. # producing the lint.
rules: rules:
# avoid_print: false # Uncomment to disable the `avoid_print` rule avoid_print: true # Enable the `avoid_print` rule
# prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule prefer_single_quotes: true # Enable the `prefer_single_quotes` rule
# Additional information about this file can be found at # Additional information about this file can be found at
# https://dart.dev/guides/language/analysis-options # https://dart.dev/guides/language/analysis-options
+1 -1
View File
@@ -25,7 +25,7 @@
<category android:name="android.intent.category.LAUNCHER"/> <category android:name="android.intent.category.LAUNCHER"/>
</intent-filter> </intent-filter>
</activity> </activity>
<!-- Don't delete the meta-data below. <!-- Don't delete the meta-data below.
This is used by the Flutter tool to generate GeneratedPluginRegistrant.java --> This is used by the Flutter tool to generate GeneratedPluginRegistrant.java -->
<meta-data <meta-data
android:name="flutterEmbedding" android:name="flutterEmbedding"
+1 -1
View File
@@ -12,7 +12,7 @@ class MyApp extends StatelessWidget {
// This widget is the root of your application. // This widget is the root of your application.
@override @override
Widget build(BuildContext context) { Widget build(BuildContext context) {
logger.i("Building main app"); logger.i('Building main app');
return MaterialApp( return MaterialApp(
title: 'Tensordex', title: 'Tensordex',
theme: ThemeData( theme: ThemeData(
+19 -22
View File
@@ -1,21 +1,23 @@
import 'dart:math';
import 'package:collection/collection.dart'; import 'package:collection/collection.dart';
import 'package:image/image.dart' as image_lib; import 'package:image/image.dart' as image_lib;
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
import 'model/outputs/recognition.dart';
import '../utils/logger.dart'; import '../utils/logger.dart';
import 'data/recognition.dart'; import 'model/outputs/stats.dart';
import 'data/stats.dart';
/// Classifier /// Classifier
class Classifier { class Classifier {
static const String modelFileName = "efficientnet_v2s.tflite"; static const String modelFileName = 'efficientnet_v2s.tflite';
static const int inputSize = 224; static const int inputSize = 224;
/// [ImageProcessor] used to pre-process the image /// [ImageProcessor] used to pre-process the image
ImageProcessor? imageProcessor; ImageProcessor? imageProcessor;
///Tensor image to move image data into ///Tensor image to move image data into
late TensorImage _inputImage; late TensorImage _inputImage;
/// Instance of Interpreter /// Instance of Interpreter
@@ -30,55 +32,50 @@ class Classifier {
late List<String> _labels; late List<String> _labels;
int classifierCreationStart = -1; int classifierCreationStart = -1;
Classifier({ Classifier(
Interpreter? interpreter, Interpreter interpreter, {
List<String>? labels, List<String>? labels,
}) { }) {
loadModel(interpreter: interpreter); loadModel(interpreter);
loadLabels(labels: labels); loadLabels(labels: labels);
} }
/// Loads interpreter from asset /// Loads interpreter from asset
void loadModel({Interpreter? interpreter}) async { void loadModel(Interpreter interpreter) async {
try { try {
_interpreter = interpreter ?? _interpreter = interpreter;
await Interpreter.fromAsset(
modelFileName,
options: InterpreterOptions()..threads = 8,
);
var outputTensor = _interpreter.getOutputTensor(0); var outputTensor = _interpreter.getOutputTensor(0);
var outputShape = outputTensor.shape; var outputShape = outputTensor.shape;
_outputType = outputTensor.type; _outputType = outputTensor.type;
var inputTensor = _interpreter.getInputTensor(0); var inputTensor = _interpreter.getInputTensor(0);
// var intputShape = inputTensor.shape;
_inputType = inputTensor.type; _inputType = inputTensor.type;
_inputImage = TensorImage(_inputType); _inputImage = TensorImage(_inputType);
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType); _outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
_outputProcessor = _outputProcessor =
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build(); TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
} catch (e) { } catch (e) {
logger.e("Error while creating interpreter: ", e); logger.e('Error while creating interpreter: ', e);
} }
} }
/// Loads labels from assets /// Loads labels from assets
void loadLabels({List<String>? labels}) async { void loadLabels({List<String>? labels}) async {
try { try {
_labels = labels ?? await FileUtil.loadLabels("assets/labels.txt"); _labels = labels ?? await FileUtil.loadLabels('assets/labels.txt');
} catch (e) { } catch (e) {
logger.e("Error while loading labels: $e"); logger.e('Error while loading labels: $e');
} }
} }
/// Pre-process the image /// Pre-process the image
TensorImage? getProcessedImage(TensorImage? inputImage) { TensorImage? getProcessedImage(TensorImage? inputImage) {
// padSize = max(inputImage.height, inputImage.width); int cropSize = min(_inputImage.height, _inputImage.width);
if (inputImage != null) { if (inputImage != null) {
imageProcessor ??= ImageProcessorBuilder() imageProcessor ??= ImageProcessorBuilder()
.add(ResizeWithCropOrPadOp(224, 224)) .add(ResizeWithCropOrPadOp(cropSize, cropSize))
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR)) .add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
.add(NormalizeOp(0, 1)) .add(NormalizeOp(0, 1))
// .add(NormalizeOp(127.5, 127.5)) // .add(NormalizeOp(127.5, 127.5)) // photo vs quant normalization
.build(); .build();
return imageProcessor?.process(inputImage); return imageProcessor?.process(inputImage);
} }
@@ -102,8 +99,8 @@ class Classifier {
.toList(); .toList();
var endTime = DateTime.now().millisecondsSinceEpoch; var endTime = DateTime.now().millisecondsSinceEpoch;
return { return {
"recognitions": predictions, 'recognitions': predictions,
"stats": Stats( 'stats': Stats(
totalTime: endTime - preProcStart, totalTime: endTime - preProcStart,
preProcessingTime: inferenceStart - preProcStart, preProcessingTime: inferenceStart - preProcStart,
inferenceTime: postProcStart - inferenceStart, inferenceTime: postProcStart - inferenceStart,
+4 -5
View File
@@ -11,7 +11,7 @@ class IsolateBase {
} }
class MLIsolate extends IsolateBase { class MLIsolate extends IsolateBase {
static const String debugIsolate = "MLIsolate"; static const String debugIsolate = 'MLIsolate';
late SendPort _sendPort; late SendPort _sendPort;
SendPort get sendPort => _sendPort; SendPort get sendPort => _sendPort;
@@ -34,19 +34,18 @@ class MLIsolate extends IsolateBase {
var converted = ImageUtils.convertCameraImage(cameraImage); var converted = ImageUtils.convertCameraImage(cameraImage);
if (converted != null) { if (converted != null) {
Classifier classifier = Classifier( Classifier classifier = Classifier(
interpreter: Interpreter.fromAddress(mlIsolateData.interpreterAddress),
Interpreter.fromAddress(mlIsolateData.interpreterAddress),
labels: mlIsolateData.labels); labels: mlIsolateData.labels);
var result = classifier.predict(converted); var result = classifier.predict(converted);
mlIsolateData.responsePort?.send(result); mlIsolateData.responsePort?.send(result);
} else { } else {
mlIsolateData.responsePort?.send({"response": "not working yet"}); mlIsolateData.responsePort?.send({'response': 'not working yet'});
} }
} }
} }
} }
/// Bundles data to pass between Isolate /// Bundles data to pass between Isolate
class MLIsolateData { class MLIsolateData {
CameraImage cameraImage; CameraImage cameraImage;
int interpreterAddress; int interpreterAddress;
+16
View File
@@ -0,0 +1,16 @@
import 'package:tflite_flutter/tflite_flutter.dart';
import 'constants.dart';
class ModelConfiguration{
String name;
late List<InterpreterOptions> interpreters;
ModelConfiguration(this.name){
interpreters = name.contains('gpu') ? ModelConstants.gpuInterpreterList : ModelConstants.cpuInterpreterList;
}
@override
String toString() {
return 'ModelConfiguration(name: $name, interpreters: $interpreters)';
}
}
+10
View File
@@ -0,0 +1,10 @@
import 'package:tflite_flutter/tflite_flutter.dart';
class ModelConstants {
static final InterpreterOptions _npuConfig = InterpreterOptions()..threads = 8..useNnApiForAndroid = true..useMetalDelegateForIOS = true;
static final InterpreterOptions _cpuConfig = InterpreterOptions()..threads = 8;
static final List<InterpreterOptions> gpuInterpreterList = [_npuConfig, _cpuConfig];
static final List<InterpreterOptions> cpuInterpreterList = [_cpuConfig];
}
+38 -17
View File
@@ -1,16 +1,19 @@
import 'dart:convert';
import 'dart:isolate'; import 'dart:isolate';
import 'package:camera/camera.dart'; import 'package:camera/camera.dart';
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:tensordex_mobile/tflite/ml_isolate.dart'; import 'package:tensordex_mobile/tflite/ml_isolate.dart';
import 'package:tensordex_mobile/tflite/model/configuration.dart';
import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter/tflite_flutter.dart';
import '../tflite/classifier.dart'; import '../tflite/classifier.dart';
import '../tflite/model/outputs/recognition.dart';
import '../utils/logger.dart'; import '../utils/logger.dart';
import '../tflite/data/recognition.dart';
import '../tflite/data/stats.dart';
/// [PokeFinder] sends each frame for inference
class PokeFinder extends StatefulWidget { class PokeFinder extends StatefulWidget {
/// Callback to pass results after inference to [HomeView] /// Callback to pass results after inference to [HomeView]
final Function(List<Recognition> recognitions) resultsCallback; final Function(List<Recognition> recognitions) resultsCallback;
@@ -28,17 +31,20 @@ class PokeFinder extends StatefulWidget {
} }
class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver { class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
late List<CameraDescription> cameras;
late CameraController cameraController;
late MLIsolate _mlIsolate;
/// true when inference is ongoing /// true when inference is ongoing
bool predicting = false; bool predicting = false;
bool _cameraInitialized = false; bool _cameraInitialized = false;
bool _classifierInitialized = false; bool _classifierInitialized = false;
//cameras
late List<CameraDescription> cameras;
late CameraController cameraController;
//ml variables
late Interpreter interpreter; late Interpreter interpreter;
late Classifier classifier; late Classifier classifier;
late MLIsolate _mlIsolate;
late List<ModelConfiguration> modelConfigurations;
@override @override
void initState() { void initState() {
@@ -55,19 +61,34 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
predicting = false; predicting = false;
} }
Future<List<String>> getModelFiles() async {
final manifestContent = await rootBundle.loadString('AssetManifest.json');
final Map<String, dynamic> manifestMap = json.decode(manifestContent);
return manifestMap.keys
.where((String key) => key.contains('.tflite'))
.map((String key) => key.substring(7))
.toList();
}
void initializeModel() async { void initializeModel() async {
var interpreterOptions = InterpreterOptions()..threads = 8; var modelFiles = await getModelFiles();
interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite', var modelConfigurations =
options: interpreterOptions); modelFiles.map((e) => ModelConfiguration(e)).toList();
classifier = Classifier(interpreter: interpreter); var currentConfig = modelConfigurations[0];
logger.i(modelFiles);
interpreter = await createInterpreter(currentConfig);
classifier = Classifier(interpreter);
_classifierInitialized = true; _classifierInitialized = true;
} }
Future<Interpreter> createInterpreter(ModelConfiguration config) async {
return await Interpreter.fromAsset(config.name,
options: config.interpreters[0]);
}
/// Initializes the camera by setting [cameraController] /// Initializes the camera by setting [cameraController]
void initializeCamera() async { void initializeCamera() async {
cameras = await availableCameras(); cameras = await availableCameras();
// cameras[0] for rear-camera
cameraController = cameraController =
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false); CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
@@ -94,11 +115,11 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
var results = await inference(MLIsolateData( var results = await inference(MLIsolateData(
cameraImage, classifier.interpreter.address, classifier.labels)); cameraImage, classifier.interpreter.address, classifier.labels));
if (results.containsKey("recognitions")) { if (results.containsKey('recognitions')) {
widget.resultsCallback(results["recognitions"]); widget.resultsCallback(results['recognitions']);
} }
if (results.containsKey("stats")) { if (results.containsKey('stats')) {
widget.statsCallback(results["stats"]); widget.statsCallback(results['stats']);
} }
logger.i(results); logger.i(results);
+2 -2
View File
@@ -1,7 +1,7 @@
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart'; import 'package:tensordex_mobile/widgets/poke_finder.dart';
import 'package:tensordex_mobile/tflite/data/recognition.dart'; import '../tflite/model/outputs/recognition.dart';
import 'package:tensordex_mobile/tflite/data/stats.dart'; import '../tflite/model/outputs/stats.dart';
/// [PokeFinder] sends each frame for inference /// [PokeFinder] sends each frame for inference
+4 -4
View File
@@ -1,10 +1,10 @@
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:tensordex_mobile/tflite/model/outputs/recognition.dart';
import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart'; import 'package:tensordex_mobile/widgets/poke_finder.dart';
import 'package:tensordex_mobile/widgets/results.dart'; import 'package:tensordex_mobile/widgets/results.dart';
import '../utils/logger.dart'; import '../utils/logger.dart';
import '../tflite/data/recognition.dart';
import '../tflite/data/stats.dart';
class TensordexHome extends StatefulWidget { class TensordexHome extends StatefulWidget {
const TensordexHome({Key? key, required this.title}) : super(key: key); const TensordexHome({Key? key, required this.title}) : super(key: key);
@@ -22,7 +22,7 @@ class TensordexHome extends StatefulWidget {
class _TensordexHomeState extends State<TensordexHome> { class _TensordexHomeState extends State<TensordexHome> {
/// Results from the image classifier /// Results from the image classifier
List<Recognition> results = [Recognition(1, "NOTHING DETECTED", .5)]; List<Recognition> results = [Recognition(1, 'NOTHING DETECTED', .5)];
Stats stats = Stats(); Stats stats = Stats();
/// Scaffold Key /// Scaffold Key
@@ -30,7 +30,7 @@ class _TensordexHomeState extends State<TensordexHome> {
void _incrementCounter() { void _incrementCounter() {
setState(() { setState(() {
logger.d("Counter Incremented!"); logger.d('Counter Incremented!');
}); });
} }