Fixed the classifier and added a preliminary results view that shows which Pokémon are currently being examined.
This commit is contained in:
+57
-80
@@ -1,42 +1,35 @@
|
||||
import 'dart:math';
|
||||
import 'dart:ui';
|
||||
|
||||
import 'package:collection/collection.dart';
|
||||
import 'package:image/image.dart' as image_lib;
|
||||
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
|
||||
|
||||
import '../utils/logger.dart';
|
||||
import '../utils/recognition.dart';
|
||||
import '../utils/stats.dart';
|
||||
import 'data/recognition.dart';
|
||||
import 'data/stats.dart';
|
||||
|
||||
/// Classifier
|
||||
class Classifier {
|
||||
static const String MODEL_FILE_NAME = "detect.tflite";
|
||||
static const String LABEL_FILE_NAME = "labelmap.txt";
|
||||
|
||||
/// Input size of image (height = width = 300)
|
||||
static const int INPUT_SIZE = 224;
|
||||
|
||||
/// Result score threshold
|
||||
static const double THRESHOLD = 0.5;
|
||||
static const String modelFileName = "efficientnet_v2s.tflite";
|
||||
static const int inputSize = 224;
|
||||
|
||||
/// [ImageProcessor] used to pre-process the image
|
||||
ImageProcessor? imageProcessor;
|
||||
|
||||
/// Padding the image to transform into square
|
||||
// int padSize = 0;
|
||||
///Tensor image to move image data into
|
||||
late TensorImage _inputImage;
|
||||
|
||||
/// Instance of Interpreter
|
||||
late Interpreter _interpreter;
|
||||
|
||||
late TensorBuffer _outputBuffer;
|
||||
late var _probabilityProcessor;
|
||||
late TfLiteType _inputType;
|
||||
late TfLiteType _outputType;
|
||||
|
||||
late SequentialProcessor<TensorBuffer> _outputProcessor;
|
||||
|
||||
/// Labels file loaded as list
|
||||
late List<String> _labels;
|
||||
int classifierCreationStart = -1;
|
||||
|
||||
/// Number of results to show
|
||||
static const int NUM_RESULTS = 10;
|
||||
|
||||
Classifier({
|
||||
Interpreter? interpreter,
|
||||
@@ -51,19 +44,18 @@ class Classifier {
|
||||
try {
|
||||
_interpreter = interpreter ??
|
||||
await Interpreter.fromAsset(
|
||||
MODEL_FILE_NAME,
|
||||
options: InterpreterOptions()..threads = 4,
|
||||
modelFileName,
|
||||
options: InterpreterOptions()..threads = 8,
|
||||
);
|
||||
var outputTensor = _interpreter.getOutputTensor(0);
|
||||
var outputShape = outputTensor.shape;
|
||||
var outputType = outputTensor.type;
|
||||
|
||||
_outputType = outputTensor.type;
|
||||
var inputTensor = _interpreter.getInputTensor(0);
|
||||
var intputShape = inputTensor.shape;
|
||||
var intputType = inputTensor.type;
|
||||
|
||||
_outputBuffer = TensorBuffer.createFixedSize(outputShape, outputType);
|
||||
_probabilityProcessor =
|
||||
// var intputShape = inputTensor.shape;
|
||||
_inputType = inputTensor.type;
|
||||
_inputImage = TensorImage(_inputType);
|
||||
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
|
||||
_outputProcessor =
|
||||
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
|
||||
} catch (e) {
|
||||
logger.e("Error while creating interpreter: ", e);
|
||||
@@ -80,61 +72,45 @@ class Classifier {
|
||||
}
|
||||
|
||||
/// Pre-process the image
|
||||
TensorImage? getProcessedImage(TensorImage inputImage) {
|
||||
TensorImage? getProcessedImage(TensorImage? inputImage) {
|
||||
// padSize = max(inputImage.height, inputImage.width);
|
||||
imageProcessor ??= ImageProcessorBuilder()
|
||||
// .add(ResizeWithCropOrPadOp(padSize, padSize))
|
||||
.add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR))
|
||||
.add(NormalizeOp(127.5, 127.5))
|
||||
.build();
|
||||
return imageProcessor?.process(inputImage);
|
||||
if (inputImage != null) {
|
||||
imageProcessor ??= ImageProcessorBuilder()
|
||||
.add(ResizeWithCropOrPadOp(224, 224))
|
||||
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
|
||||
// .add(NormalizeOp(127.5, 127.5))
|
||||
.build();
|
||||
return imageProcessor?.process(inputImage);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
/// Runs object detection on the input image
|
||||
Map<String, dynamic>? predict(image_lib.Image image) {
|
||||
logger.i(labels);
|
||||
var predictStartTime = DateTime.now().millisecondsSinceEpoch;
|
||||
if (_interpreter == null) {
|
||||
logger.e("Interpreter not initialized");
|
||||
return null;
|
||||
}
|
||||
var preProcessStart = DateTime.now().millisecondsSinceEpoch;
|
||||
// Create TensorImage from image
|
||||
// Pre-process TensorImage
|
||||
var procImage = getProcessedImage(TensorImage.fromImage(image));
|
||||
|
||||
var preProcessElapsedTime =
|
||||
DateTime.now().millisecondsSinceEpoch - preProcessStart;
|
||||
if (procImage != null) {
|
||||
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
|
||||
// run inference
|
||||
var inferenceTimeElapsed =
|
||||
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
|
||||
|
||||
logger.i("Sending image to ML");
|
||||
|
||||
logger.i(procImage.buffer.asFloat32List());
|
||||
logger.i(procImage.width);
|
||||
logger.i(procImage.height);
|
||||
logger.i(procImage.tensorBuffer.shape);
|
||||
logger.i(procImage.tensorBuffer.isDynamic);
|
||||
_interpreter.run(procImage.buffer, _outputBuffer.getBuffer());
|
||||
|
||||
Map<String, double> labeledProb = TensorLabel.fromList(
|
||||
labels, _probabilityProcessor.process(_outputBuffer))
|
||||
.getMapWithFloatValue();
|
||||
final pred = getTopProbability(labeledProb);
|
||||
Recognition rec = Recognition(1, pred.key, pred.value);
|
||||
var predictElapsedTime = DateTime.now().millisecondsSinceEpoch - predictStartTime;
|
||||
return {
|
||||
"recognitions": rec,
|
||||
"stats": Stats(predictElapsedTime, predictElapsedTime, predictElapsedTime, predictElapsedTime),
|
||||
};
|
||||
} else {
|
||||
return null;
|
||||
}
|
||||
var preProcStart = DateTime.now().millisecondsSinceEpoch;
|
||||
_inputImage.loadImage(image);
|
||||
_inputImage = getProcessedImage(_inputImage)!;
|
||||
var inferenceStart = DateTime.now().millisecondsSinceEpoch;
|
||||
_interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
|
||||
var postProcStart = DateTime.now().millisecondsSinceEpoch;
|
||||
Map<String, double> labeledProb = TensorLabel.fromList(
|
||||
labels, _outputProcessor.process(_outputBuffer))
|
||||
.getMapWithFloatValue();
|
||||
final predictions = getTopProbabilities(labeledProb, number: 5)
|
||||
.mapIndexed(
|
||||
(index, element) => Recognition(index, element.key, element.value))
|
||||
.toList();
|
||||
var endTime = DateTime.now().millisecondsSinceEpoch;
|
||||
return {
|
||||
"recognitions": predictions,
|
||||
"stats": Stats(
|
||||
totalTime: endTime - preProcStart,
|
||||
preProcessingTime: inferenceStart - preProcStart,
|
||||
inferenceTime: postProcStart - inferenceStart,
|
||||
postProcessingTime: endTime - postProcStart,
|
||||
),
|
||||
};
|
||||
}
|
||||
|
||||
/// Gets the interpreter instance
|
||||
Interpreter get interpreter => _interpreter;
|
||||
|
||||
@@ -142,11 +118,12 @@ class Classifier {
|
||||
List<String> get labels => _labels;
|
||||
}
|
||||
|
||||
/// Returns up to [number] highest-probability entries from [labeledProb],
/// ordered from most to least probable (ordering defined by [compare]).
///
/// Never drains more entries than [labeledProb] contains, so a small label
/// map cannot cause a [StateError] from removing from an empty queue.
List<MapEntry<String, double>> getTopProbabilities(
    Map<String, double> labeledProb,
    {int number = 3}) {
  var pq = PriorityQueue<MapEntry<String, double>>(compare);
  pq.addAll(labeledProb.entries);
  // Clamp so requesting more results than there are labels is safe.
  final count = min(number, pq.length);
  return [for (var i = 0; i < count; i += 1) pq.removeFirst()];
}
|
||||
|
||||
int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {
|
||||
|
||||
@@ -0,0 +1,21 @@
|
||||
|
||||
/// Represents the recognition output from the model.
class Recognition {
  /// Index of the result.
  final int id;

  /// Label of the result.
  final String label;

  /// Confidence score, from 0.0 to 1.0.
  final double score;

  Recognition(this.id, this.label, this.score);

  @override
  String toString() {
    return 'Recognition(id: $id, label: $label, score: $score)';
  }
}
|
||||
@@ -0,0 +1,18 @@
|
||||
|
||||
/// Timing statistics for a single prediction pass, in milliseconds
/// (callers compute these from `millisecondsSinceEpoch` deltas).
class Stats {
  /// Wall-clock time for the entire predict call.
  int totalTime;

  /// Time spent preparing the input image.
  int preProcessingTime;

  /// Time spent running the interpreter.
  int inferenceTime;

  /// Time spent labelling and ranking the output.
  int postProcessingTime;

  /// Each value defaults to -1, meaning "not measured".
  Stats({
    this.totalTime = -1,
    this.preProcessingTime = -1,
    this.inferenceTime = -1,
    this.postProcessingTime = -1,
  });

  @override
  String toString() {
    return 'Stats{totalPredictTime: $totalTime, preProcessingTime: $preProcessingTime, inferenceTime: $inferenceTime, postProcessingTime: $postProcessingTime}';
  }
}
|
||||
@@ -0,0 +1,62 @@
|
||||
import 'dart:isolate';
|
||||
|
||||
import 'package:camera/camera.dart';
|
||||
import 'package:tensordex_mobile/tflite/classifier.dart';
|
||||
import 'package:tflite_flutter/tflite_flutter.dart';
|
||||
|
||||
import '../utils/image_utils.dart';
|
||||
import '../utils/logger.dart';
|
||||
|
||||
/// Shared base for isolate wrappers: owns the port on which messages from
/// the spawned isolate are received.
class IsolateBase {
  // Library-private rather than class-private: subclasses in this same file
  // (e.g. MLIsolate) read it directly to hand its SendPort to the isolate.
  final ReceivePort _receivePort = ReceivePort();
}
|
||||
|
||||
/// Runs the classifier on camera frames inside a background isolate.
///
/// Protocol: [start] spawns the isolate and waits for it to send back its
/// own SendPort; callers then post [MLIsolateData] messages to [sendPort]
/// and receive prediction results on the message's responsePort.
class MLIsolate extends IsolateBase {
  /// Debug name given to the spawned isolate (shows up in tooling).
  static const String debugIsolate = "MLIsolate";

  // Set once in start(); reading it before start() completes would throw.
  late SendPort _sendPort;

  /// Port for posting [MLIsolateData] work items to the isolate.
  SendPort get sendPort => _sendPort;

  /// Spawns the worker isolate and completes once the port handshake is done.
  Future<void> start() async {
    await Isolate.spawn<SendPort>(
      entryPoint,
      _receivePort.sendPort,
      debugName: debugIsolate,
    );
    // The first message the isolate sends is its own SendPort (see
    // entryPoint). `first` is dynamic, so this is an implicit downcast.
    _sendPort = await _receivePort.first;
  }

  /// Isolate entry point: completes the port handshake, then processes
  /// incoming frames forever.
  static void entryPoint(SendPort sendPort) async {
    final port = ReceivePort();
    sendPort.send(port.sendPort);

    await for (final MLIsolateData mlIsolateData in port) {
      var cameraImage = mlIsolateData.cameraImage;
      var converted = ImageUtils.convertCameraImage(cameraImage);
      if (converted != null) {
        // NOTE(review): a fresh Classifier is constructed per frame; the
        // interpreter itself is shared across isolates via its address,
        // so presumably only the wrapper is rebuilt — confirm this is
        // cheap enough for the frame rate.
        Classifier classifier = Classifier(
            interpreter:
                Interpreter.fromAddress(mlIsolateData.interpreterAddress),
            labels: mlIsolateData.labels);
        var result = classifier.predict(converted);
        // Reply on the per-request port if the sender supplied one.
        mlIsolateData.responsePort?.send(result);
      } else {
        // Conversion failed: reply with a placeholder so the caller's
        // await does not hang.
        mlIsolateData.responsePort?.send({"response": "not working yet"});
      }
    }
  }
}
|
||||
|
||||
/// Bundles data to pass between Isolate
|
||||
/// Bundles the data passed into the ML isolate for one prediction request.
class MLIsolateData {
  /// Raw camera frame to classify.
  CameraImage cameraImage;

  /// Native address of the shared [Interpreter]; the isolate reattaches to
  /// it via `Interpreter.fromAddress`.
  int interpreterAddress;

  /// Label list for mapping model outputs to names.
  List<String> labels;

  /// Optional port the isolate replies on; set by the sender after
  /// construction.
  SendPort? responsePort;

  MLIsolateData(this.cameraImage, this.interpreterAddress, this.labels);
}
|
||||
Reference in New Issue
Block a user