adding some basic linting - prepping support for multiple models being loaded in by the app.
This commit is contained in:
@@ -45,3 +45,5 @@ app.*.map.json
|
|||||||
/android/app/debug
|
/android/app/debug
|
||||||
/android/app/profile
|
/android/app/profile
|
||||||
/android/app/release
|
/android/app/release
|
||||||
|
/assets/mobilenetv2_gpu.tflite
|
||||||
|
/assets/mobilenetv2_gpu.tflite
|
||||||
|
|||||||
@@ -22,8 +22,8 @@ linter:
|
|||||||
# `// ignore_for_file: name_of_lint` syntax on the line or in the file
|
# `// ignore_for_file: name_of_lint` syntax on the line or in the file
|
||||||
# producing the lint.
|
# producing the lint.
|
||||||
rules:
|
rules:
|
||||||
# avoid_print: false # Uncomment to disable the `avoid_print` rule
|
avoid_print: true # Uncomment to disable the `avoid_print` rule
|
||||||
# prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule
|
prefer_single_quotes: true # Uncomment to enable the `prefer_single_quotes` rule
|
||||||
|
|
||||||
# Additional information about this file can be found at
|
# Additional information about this file can be found at
|
||||||
# https://dart.dev/guides/language/analysis-options
|
# https://dart.dev/guides/language/analysis-options
|
||||||
|
|||||||
@@ -25,7 +25,7 @@
|
|||||||
<category android:name="android.intent.category.LAUNCHER"/>
|
<category android:name="android.intent.category.LAUNCHER"/>
|
||||||
</intent-filter>
|
</intent-filter>
|
||||||
</activity>
|
</activity>
|
||||||
<!-- Don't delete the meta-data below.
|
<!-- Don't delete the meta-outputs below.
|
||||||
This is used by the Flutter tool to generate GeneratedPluginRegistrant.java -->
|
This is used by the Flutter tool to generate GeneratedPluginRegistrant.java -->
|
||||||
<meta-data
|
<meta-data
|
||||||
android:name="flutterEmbedding"
|
android:name="flutterEmbedding"
|
||||||
|
|||||||
+1
-1
@@ -12,7 +12,7 @@ class MyApp extends StatelessWidget {
|
|||||||
// This widget is the root of your application.
|
// This widget is the root of your application.
|
||||||
@override
|
@override
|
||||||
Widget build(BuildContext context) {
|
Widget build(BuildContext context) {
|
||||||
logger.i("Building main app");
|
logger.i('Building main app');
|
||||||
return MaterialApp(
|
return MaterialApp(
|
||||||
title: 'Tensordex',
|
title: 'Tensordex',
|
||||||
theme: ThemeData(
|
theme: ThemeData(
|
||||||
|
|||||||
+19
-22
@@ -1,21 +1,23 @@
|
|||||||
|
import 'dart:math';
|
||||||
|
|
||||||
import 'package:collection/collection.dart';
|
import 'package:collection/collection.dart';
|
||||||
import 'package:image/image.dart' as image_lib;
|
import 'package:image/image.dart' as image_lib;
|
||||||
import 'package:tflite_flutter/tflite_flutter.dart';
|
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||||
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
|
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
|
||||||
|
|
||||||
|
import 'model/outputs/recognition.dart';
|
||||||
import '../utils/logger.dart';
|
import '../utils/logger.dart';
|
||||||
import 'data/recognition.dart';
|
import 'model/outputs/stats.dart';
|
||||||
import 'data/stats.dart';
|
|
||||||
|
|
||||||
/// Classifier
|
/// Classifier
|
||||||
class Classifier {
|
class Classifier {
|
||||||
static const String modelFileName = "efficientnet_v2s.tflite";
|
static const String modelFileName = 'efficientnet_v2s.tflite';
|
||||||
static const int inputSize = 224;
|
static const int inputSize = 224;
|
||||||
|
|
||||||
/// [ImageProcessor] used to pre-process the image
|
/// [ImageProcessor] used to pre-process the image
|
||||||
ImageProcessor? imageProcessor;
|
ImageProcessor? imageProcessor;
|
||||||
|
|
||||||
///Tensor image to move image data into
|
///Tensor image to move image outputs into
|
||||||
late TensorImage _inputImage;
|
late TensorImage _inputImage;
|
||||||
|
|
||||||
/// Instance of Interpreter
|
/// Instance of Interpreter
|
||||||
@@ -30,55 +32,50 @@ class Classifier {
|
|||||||
late List<String> _labels;
|
late List<String> _labels;
|
||||||
int classifierCreationStart = -1;
|
int classifierCreationStart = -1;
|
||||||
|
|
||||||
Classifier({
|
Classifier(
|
||||||
Interpreter? interpreter,
|
Interpreter interpreter, {
|
||||||
List<String>? labels,
|
List<String>? labels,
|
||||||
}) {
|
}) {
|
||||||
loadModel(interpreter: interpreter);
|
loadModel(interpreter);
|
||||||
loadLabels(labels: labels);
|
loadLabels(labels: labels);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loads interpreter from asset
|
/// Loads interpreter from asset
|
||||||
void loadModel({Interpreter? interpreter}) async {
|
void loadModel(Interpreter interpreter) async {
|
||||||
try {
|
try {
|
||||||
_interpreter = interpreter ??
|
_interpreter = interpreter;
|
||||||
await Interpreter.fromAsset(
|
|
||||||
modelFileName,
|
|
||||||
options: InterpreterOptions()..threads = 8,
|
|
||||||
);
|
|
||||||
var outputTensor = _interpreter.getOutputTensor(0);
|
var outputTensor = _interpreter.getOutputTensor(0);
|
||||||
var outputShape = outputTensor.shape;
|
var outputShape = outputTensor.shape;
|
||||||
_outputType = outputTensor.type;
|
_outputType = outputTensor.type;
|
||||||
var inputTensor = _interpreter.getInputTensor(0);
|
var inputTensor = _interpreter.getInputTensor(0);
|
||||||
// var intputShape = inputTensor.shape;
|
|
||||||
_inputType = inputTensor.type;
|
_inputType = inputTensor.type;
|
||||||
_inputImage = TensorImage(_inputType);
|
_inputImage = TensorImage(_inputType);
|
||||||
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
|
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
|
||||||
_outputProcessor =
|
_outputProcessor =
|
||||||
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
|
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
logger.e("Error while creating interpreter: ", e);
|
logger.e('Error while creating interpreter: ', e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Loads labels from assets
|
/// Loads labels from assets
|
||||||
void loadLabels({List<String>? labels}) async {
|
void loadLabels({List<String>? labels}) async {
|
||||||
try {
|
try {
|
||||||
_labels = labels ?? await FileUtil.loadLabels("assets/labels.txt");
|
_labels = labels ?? await FileUtil.loadLabels('assets/labels.txt');
|
||||||
} catch (e) {
|
} catch (e) {
|
||||||
logger.e("Error while loading labels: $e");
|
logger.e('Error while loading labels: $e');
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pre-process the image
|
/// Pre-process the image
|
||||||
TensorImage? getProcessedImage(TensorImage? inputImage) {
|
TensorImage? getProcessedImage(TensorImage? inputImage) {
|
||||||
// padSize = max(inputImage.height, inputImage.width);
|
int cropSize = min(_inputImage.height, _inputImage.width);
|
||||||
if (inputImage != null) {
|
if (inputImage != null) {
|
||||||
imageProcessor ??= ImageProcessorBuilder()
|
imageProcessor ??= ImageProcessorBuilder()
|
||||||
.add(ResizeWithCropOrPadOp(224, 224))
|
.add(ResizeWithCropOrPadOp(cropSize, cropSize))
|
||||||
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
||||||
.add(NormalizeOp(0, 1))
|
.add(NormalizeOp(0, 1))
|
||||||
// .add(NormalizeOp(127.5, 127.5))
|
// .add(NormalizeOp(127.5, 127.5)) // photo vs quant normalization
|
||||||
.build();
|
.build();
|
||||||
return imageProcessor?.process(inputImage);
|
return imageProcessor?.process(inputImage);
|
||||||
}
|
}
|
||||||
@@ -102,8 +99,8 @@ class Classifier {
|
|||||||
.toList();
|
.toList();
|
||||||
var endTime = DateTime.now().millisecondsSinceEpoch;
|
var endTime = DateTime.now().millisecondsSinceEpoch;
|
||||||
return {
|
return {
|
||||||
"recognitions": predictions,
|
'recognitions': predictions,
|
||||||
"stats": Stats(
|
'stats': Stats(
|
||||||
totalTime: endTime - preProcStart,
|
totalTime: endTime - preProcStart,
|
||||||
preProcessingTime: inferenceStart - preProcStart,
|
preProcessingTime: inferenceStart - preProcStart,
|
||||||
inferenceTime: postProcStart - inferenceStart,
|
inferenceTime: postProcStart - inferenceStart,
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ class IsolateBase {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class MLIsolate extends IsolateBase {
|
class MLIsolate extends IsolateBase {
|
||||||
static const String debugIsolate = "MLIsolate";
|
static const String debugIsolate = 'MLIsolate';
|
||||||
late SendPort _sendPort;
|
late SendPort _sendPort;
|
||||||
|
|
||||||
SendPort get sendPort => _sendPort;
|
SendPort get sendPort => _sendPort;
|
||||||
@@ -34,19 +34,18 @@ class MLIsolate extends IsolateBase {
|
|||||||
var converted = ImageUtils.convertCameraImage(cameraImage);
|
var converted = ImageUtils.convertCameraImage(cameraImage);
|
||||||
if (converted != null) {
|
if (converted != null) {
|
||||||
Classifier classifier = Classifier(
|
Classifier classifier = Classifier(
|
||||||
interpreter:
|
Interpreter.fromAddress(mlIsolateData.interpreterAddress),
|
||||||
Interpreter.fromAddress(mlIsolateData.interpreterAddress),
|
|
||||||
labels: mlIsolateData.labels);
|
labels: mlIsolateData.labels);
|
||||||
var result = classifier.predict(converted);
|
var result = classifier.predict(converted);
|
||||||
mlIsolateData.responsePort?.send(result);
|
mlIsolateData.responsePort?.send(result);
|
||||||
} else {
|
} else {
|
||||||
mlIsolateData.responsePort?.send({"response": "not working yet"});
|
mlIsolateData.responsePort?.send({'response': 'not working yet'});
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Bundles data to pass between Isolate
|
/// Bundles outputs to pass between Isolate
|
||||||
class MLIsolateData {
|
class MLIsolateData {
|
||||||
CameraImage cameraImage;
|
CameraImage cameraImage;
|
||||||
int interpreterAddress;
|
int interpreterAddress;
|
||||||
|
|||||||
@@ -0,0 +1,16 @@
|
|||||||
|
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||||
|
import 'constants.dart';
|
||||||
|
|
||||||
|
class ModelConfiguration{
|
||||||
|
String name;
|
||||||
|
late List<InterpreterOptions> interpreters;
|
||||||
|
|
||||||
|
ModelConfiguration(this.name){
|
||||||
|
interpreters = name.contains('gpu') ? ModelConstants.gpuInterpreterList : ModelConstants.cpuInterpreterList;
|
||||||
|
}
|
||||||
|
|
||||||
|
@override
|
||||||
|
String toString() {
|
||||||
|
return 'ModelConfiguration(name: $name, interpreters: $interpreters)';
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConstants {
|
||||||
|
static final InterpreterOptions _npuConfig = InterpreterOptions()..threads = 8..useNnApiForAndroid = true..useMetalDelegateForIOS = true;
|
||||||
|
static final InterpreterOptions _cpuConfig = InterpreterOptions()..threads = 8;
|
||||||
|
static final List<InterpreterOptions> gpuInterpreterList = [_npuConfig, _cpuConfig];
|
||||||
|
static final List<InterpreterOptions> cpuInterpreterList = [_cpuConfig];
|
||||||
|
}
|
||||||
|
|
||||||
@@ -1,16 +1,19 @@
|
|||||||
|
import 'dart:convert';
|
||||||
import 'dart:isolate';
|
import 'dart:isolate';
|
||||||
|
|
||||||
import 'package:camera/camera.dart';
|
import 'package:camera/camera.dart';
|
||||||
import 'package:flutter/material.dart';
|
import 'package:flutter/material.dart';
|
||||||
|
import 'package:flutter/services.dart';
|
||||||
import 'package:tensordex_mobile/tflite/ml_isolate.dart';
|
import 'package:tensordex_mobile/tflite/ml_isolate.dart';
|
||||||
|
import 'package:tensordex_mobile/tflite/model/configuration.dart';
|
||||||
|
import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
|
||||||
import 'package:tflite_flutter/tflite_flutter.dart';
|
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||||
|
|
||||||
import '../tflite/classifier.dart';
|
import '../tflite/classifier.dart';
|
||||||
|
import '../tflite/model/outputs/recognition.dart';
|
||||||
import '../utils/logger.dart';
|
import '../utils/logger.dart';
|
||||||
import '../tflite/data/recognition.dart';
|
|
||||||
import '../tflite/data/stats.dart';
|
|
||||||
|
|
||||||
/// [PokeFinder] sends each frame for inference
|
|
||||||
class PokeFinder extends StatefulWidget {
|
class PokeFinder extends StatefulWidget {
|
||||||
/// Callback to pass results after inference to [HomeView]
|
/// Callback to pass results after inference to [HomeView]
|
||||||
final Function(List<Recognition> recognitions) resultsCallback;
|
final Function(List<Recognition> recognitions) resultsCallback;
|
||||||
@@ -28,17 +31,20 @@ class PokeFinder extends StatefulWidget {
|
|||||||
}
|
}
|
||||||
|
|
||||||
class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
|
class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
|
||||||
late List<CameraDescription> cameras;
|
|
||||||
late CameraController cameraController;
|
|
||||||
late MLIsolate _mlIsolate;
|
|
||||||
|
|
||||||
/// true when inference is ongoing
|
/// true when inference is ongoing
|
||||||
bool predicting = false;
|
bool predicting = false;
|
||||||
bool _cameraInitialized = false;
|
bool _cameraInitialized = false;
|
||||||
bool _classifierInitialized = false;
|
bool _classifierInitialized = false;
|
||||||
|
|
||||||
|
//cameras
|
||||||
|
late List<CameraDescription> cameras;
|
||||||
|
late CameraController cameraController;
|
||||||
|
|
||||||
|
//ml variables
|
||||||
late Interpreter interpreter;
|
late Interpreter interpreter;
|
||||||
late Classifier classifier;
|
late Classifier classifier;
|
||||||
|
late MLIsolate _mlIsolate;
|
||||||
|
late List<ModelConfiguration> modelConfigurations;
|
||||||
|
|
||||||
@override
|
@override
|
||||||
void initState() {
|
void initState() {
|
||||||
@@ -55,19 +61,34 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
|
|||||||
predicting = false;
|
predicting = false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<List<String>> getModelFiles() async {
|
||||||
|
final manifestContent = await rootBundle.loadString('AssetManifest.jsn');
|
||||||
|
final Map<String, dynamic> manifestMap = json.decode(manifestContent);
|
||||||
|
return manifestMap.keys
|
||||||
|
.where((String key) => key.contains('.tflite'))
|
||||||
|
.map((String key) => key.substring(7))
|
||||||
|
.toList();
|
||||||
|
}
|
||||||
|
|
||||||
void initializeModel() async {
|
void initializeModel() async {
|
||||||
var interpreterOptions = InterpreterOptions()..threads = 8;
|
var modelFiles = await getModelFiles();
|
||||||
interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite',
|
var modelConfigurations =
|
||||||
options: interpreterOptions);
|
modelFiles.map((e) => ModelConfiguration(e)).toList();
|
||||||
classifier = Classifier(interpreter: interpreter);
|
var currentConfig = modelConfigurations[0];
|
||||||
|
logger.i(modelFiles);
|
||||||
|
interpreter = await createInterpreter(currentConfig);
|
||||||
|
classifier = Classifier(interpreter);
|
||||||
_classifierInitialized = true;
|
_classifierInitialized = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Future<Interpreter> createInterpreter(ModelConfiguration config) async {
|
||||||
|
return await Interpreter.fromAsset(config.name,
|
||||||
|
options: config.interpreters[0]);
|
||||||
|
}
|
||||||
|
|
||||||
/// Initializes the camera by setting [cameraController]
|
/// Initializes the camera by setting [cameraController]
|
||||||
void initializeCamera() async {
|
void initializeCamera() async {
|
||||||
cameras = await availableCameras();
|
cameras = await availableCameras();
|
||||||
|
|
||||||
// cameras[0] for rear-camera
|
|
||||||
cameraController =
|
cameraController =
|
||||||
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
|
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
|
||||||
|
|
||||||
@@ -94,11 +115,11 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
|
|||||||
var results = await inference(MLIsolateData(
|
var results = await inference(MLIsolateData(
|
||||||
cameraImage, classifier.interpreter.address, classifier.labels));
|
cameraImage, classifier.interpreter.address, classifier.labels));
|
||||||
|
|
||||||
if (results.containsKey("recognitions")) {
|
if (results.containsKey('recognitions')) {
|
||||||
widget.resultsCallback(results["recognitions"]);
|
widget.resultsCallback(results['recognitions']);
|
||||||
}
|
}
|
||||||
if (results.containsKey("stats")) {
|
if (results.containsKey('stats')) {
|
||||||
widget.statsCallback(results["stats"]);
|
widget.statsCallback(results['stats']);
|
||||||
}
|
}
|
||||||
logger.i(results);
|
logger.i(results);
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import 'package:flutter/material.dart';
|
import 'package:flutter/material.dart';
|
||||||
import 'package:tensordex_mobile/widgets/poke_finder.dart';
|
import 'package:tensordex_mobile/widgets/poke_finder.dart';
|
||||||
import 'package:tensordex_mobile/tflite/data/recognition.dart';
|
import '../tflite/model/outputs/recognition.dart';
|
||||||
import 'package:tensordex_mobile/tflite/data/stats.dart';
|
import '../tflite/model/outputs/stats.dart';
|
||||||
|
|
||||||
|
|
||||||
/// [PokeFinder] sends each frame for inference
|
/// [PokeFinder] sends each frame for inference
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import 'package:flutter/material.dart';
|
import 'package:flutter/material.dart';
|
||||||
|
import 'package:tensordex_mobile/tflite/model/outputs/recognition.dart';
|
||||||
|
import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
|
||||||
import 'package:tensordex_mobile/widgets/poke_finder.dart';
|
import 'package:tensordex_mobile/widgets/poke_finder.dart';
|
||||||
import 'package:tensordex_mobile/widgets/results.dart';
|
import 'package:tensordex_mobile/widgets/results.dart';
|
||||||
|
|
||||||
import '../utils/logger.dart';
|
import '../utils/logger.dart';
|
||||||
import '../tflite/data/recognition.dart';
|
|
||||||
import '../tflite/data/stats.dart';
|
|
||||||
|
|
||||||
class TensordexHome extends StatefulWidget {
|
class TensordexHome extends StatefulWidget {
|
||||||
const TensordexHome({Key? key, required this.title}) : super(key: key);
|
const TensordexHome({Key? key, required this.title}) : super(key: key);
|
||||||
@@ -22,7 +22,7 @@ class TensordexHome extends StatefulWidget {
|
|||||||
|
|
||||||
class _TensordexHomeState extends State<TensordexHome> {
|
class _TensordexHomeState extends State<TensordexHome> {
|
||||||
/// Results from the image classifier
|
/// Results from the image classifier
|
||||||
List<Recognition> results = [Recognition(1, "NOTHING DETECTED", .5)];
|
List<Recognition> results = [Recognition(1, 'NOTHING DETECTED', .5)];
|
||||||
Stats stats = Stats();
|
Stats stats = Stats();
|
||||||
|
|
||||||
/// Scaffold Key
|
/// Scaffold Key
|
||||||
@@ -30,7 +30,7 @@ class _TensordexHomeState extends State<TensordexHome> {
|
|||||||
|
|
||||||
void _incrementCounter() {
|
void _incrementCounter() {
|
||||||
setState(() {
|
setState(() {
|
||||||
logger.d("Counter Incremented!");
|
logger.d('Counter Incremented!');
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user