fixed classifier and added in a preliminary results view that shows what pokemon are currently being looked at.

This commit is contained in:
Lucas Oskorep
2022-06-22 21:44:15 -04:00
parent ebfbfb503d
commit 9ec737db46
10 changed files with 305 additions and 355 deletions
+57 -80
View File
@@ -1,42 +1,35 @@
import 'dart:math';
import 'dart:ui';
import 'package:collection/collection.dart';
import 'package:image/image.dart' as image_lib;
import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
import '../utils/logger.dart';
import '../utils/recognition.dart';
import '../utils/stats.dart';
import 'data/recognition.dart';
import 'data/stats.dart';
/// Classifier
class Classifier {
static const String MODEL_FILE_NAME = "detect.tflite";
static const String LABEL_FILE_NAME = "labelmap.txt";
/// Input size of image (height = width = 300)
static const int INPUT_SIZE = 224;
/// Result score threshold
static const double THRESHOLD = 0.5;
static const String modelFileName = "efficientnet_v2s.tflite";
static const int inputSize = 224;
/// [ImageProcessor] used to pre-process the image
ImageProcessor? imageProcessor;
/// Padding the image to transform into square
// int padSize = 0;
///Tensor image to move image data into
late TensorImage _inputImage;
/// Instance of Interpreter
late Interpreter _interpreter;
late TensorBuffer _outputBuffer;
late var _probabilityProcessor;
late TfLiteType _inputType;
late TfLiteType _outputType;
late SequentialProcessor<TensorBuffer> _outputProcessor;
/// Labels file loaded as list
late List<String> _labels;
int classifierCreationStart = -1;
/// Number of results to show
static const int NUM_RESULTS = 10;
Classifier({
Interpreter? interpreter,
@@ -51,19 +44,18 @@ class Classifier {
try {
_interpreter = interpreter ??
await Interpreter.fromAsset(
MODEL_FILE_NAME,
options: InterpreterOptions()..threads = 4,
modelFileName,
options: InterpreterOptions()..threads = 8,
);
var outputTensor = _interpreter.getOutputTensor(0);
var outputShape = outputTensor.shape;
var outputType = outputTensor.type;
_outputType = outputTensor.type;
var inputTensor = _interpreter.getInputTensor(0);
var intputShape = inputTensor.shape;
var intputType = inputTensor.type;
_outputBuffer = TensorBuffer.createFixedSize(outputShape, outputType);
_probabilityProcessor =
// var intputShape = inputTensor.shape;
_inputType = inputTensor.type;
_inputImage = TensorImage(_inputType);
_outputBuffer = TensorBuffer.createFixedSize(outputShape, _outputType);
_outputProcessor =
TensorProcessorBuilder().add(NormalizeOp(0, 1)).build();
} catch (e) {
logger.e("Error while creating interpreter: ", e);
@@ -80,61 +72,45 @@ class Classifier {
}
/// Pre-process the image
TensorImage? getProcessedImage(TensorImage inputImage) {
TensorImage? getProcessedImage(TensorImage? inputImage) {
// padSize = max(inputImage.height, inputImage.width);
imageProcessor ??= ImageProcessorBuilder()
// .add(ResizeWithCropOrPadOp(padSize, padSize))
.add(ResizeOp(INPUT_SIZE, INPUT_SIZE, ResizeMethod.BILINEAR))
.add(NormalizeOp(127.5, 127.5))
.build();
return imageProcessor?.process(inputImage);
if (inputImage != null) {
imageProcessor ??= ImageProcessorBuilder()
.add(ResizeWithCropOrPadOp(224, 224))
.add(ResizeOp(inputSize, inputSize, ResizeMethod.BILINEAR))
// .add(NormalizeOp(127.5, 127.5))
.build();
return imageProcessor?.process(inputImage);
}
return null;
}
/// Runs object detection on the input image
Map<String, dynamic>? predict(image_lib.Image image) {
logger.i(labels);
var predictStartTime = DateTime.now().millisecondsSinceEpoch;
if (_interpreter == null) {
logger.e("Interpreter not initialized");
return null;
}
var preProcessStart = DateTime.now().millisecondsSinceEpoch;
// Create TensorImage from image
// Pre-process TensorImage
var procImage = getProcessedImage(TensorImage.fromImage(image));
var preProcessElapsedTime =
DateTime.now().millisecondsSinceEpoch - preProcessStart;
if (procImage != null) {
var inferenceTimeStart = DateTime.now().millisecondsSinceEpoch;
// run inference
var inferenceTimeElapsed =
DateTime.now().millisecondsSinceEpoch - inferenceTimeStart;
logger.i("Sending image to ML");
logger.i(procImage.buffer.asFloat32List());
logger.i(procImage.width);
logger.i(procImage.height);
logger.i(procImage.tensorBuffer.shape);
logger.i(procImage.tensorBuffer.isDynamic);
_interpreter.run(procImage.buffer, _outputBuffer.getBuffer());
Map<String, double> labeledProb = TensorLabel.fromList(
labels, _probabilityProcessor.process(_outputBuffer))
.getMapWithFloatValue();
final pred = getTopProbability(labeledProb);
Recognition rec = Recognition(1, pred.key, pred.value);
var predictElapsedTime = DateTime.now().millisecondsSinceEpoch - predictStartTime;
return {
"recognitions": rec,
"stats": Stats(predictElapsedTime, predictElapsedTime, predictElapsedTime, predictElapsedTime),
};
} else {
return null;
}
var preProcStart = DateTime.now().millisecondsSinceEpoch;
_inputImage.loadImage(image);
_inputImage = getProcessedImage(_inputImage)!;
var inferenceStart = DateTime.now().millisecondsSinceEpoch;
_interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
var postProcStart = DateTime.now().millisecondsSinceEpoch;
Map<String, double> labeledProb = TensorLabel.fromList(
labels, _outputProcessor.process(_outputBuffer))
.getMapWithFloatValue();
final predictions = getTopProbabilities(labeledProb, number: 5)
.mapIndexed(
(index, element) => Recognition(index, element.key, element.value))
.toList();
var endTime = DateTime.now().millisecondsSinceEpoch;
return {
"recognitions": predictions,
"stats": Stats(
totalTime: endTime - preProcStart,
preProcessingTime: inferenceStart - preProcStart,
inferenceTime: postProcStart - inferenceStart,
postProcessingTime: endTime - postProcStart,
),
};
}
/// Gets the interpreter instance
Interpreter get interpreter => _interpreter;
@@ -142,11 +118,12 @@ class Classifier {
List<String> get labels => _labels;
}
MapEntry<String, double> getTopProbability(Map<String, double> labeledProb) {
List<MapEntry<String, double>> getTopProbabilities(
Map<String, double> labeledProb,
{int number = 3}) {
var pq = PriorityQueue<MapEntry<String, double>>(compare);
pq.addAll(labeledProb.entries);
return pq.first;
return [for (var i = 0; i < number; i += 1) pq.removeFirst()];
}
int compare(MapEntry<String, double> e1, MapEntry<String, double> e2) {