beginning work on ui

This commit is contained in:
Lucas Oskorep
2022-08-09 19:29:57 -04:00
parent d2e6e9a583
commit 2c0250c79d
5 changed files with 126 additions and 24 deletions
+11
View File
@@ -1,9 +1,20 @@
import 'package:camera/camera.dart';
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:flutter/services.dart';
import 'package:tensordex_mobile/widgets/tensordex_home.dart'; import 'package:tensordex_mobile/widgets/tensordex_home.dart';
import 'package:tensordex_mobile/utils/logger.dart'; import 'package:tensordex_mobile/utils/logger.dart';
late List<CameraDescription> cameras;
Future<void> main() async { Future<void> main() async {
WidgetsFlutterBinding.ensureInitialized();
cameras = await availableCameras();
SystemChrome.setPreferredOrientations([
DeviceOrientation.portraitUp,
DeviceOrientation.portraitDown,
]).then((_) {
runApp(const MyApp()); runApp(const MyApp());
});
} }
class MyApp extends StatelessWidget { class MyApp extends StatelessWidget {
+12 -2
View File
@@ -1,7 +1,9 @@
import 'dart:io';
import 'dart:math'; import 'dart:math';
import 'package:collection/collection.dart'; import 'package:collection/collection.dart';
import 'package:image/image.dart' as image_lib; import 'package:image/image.dart' as image_lib;
import 'package:path_provider/path_provider.dart';
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter/tflite_flutter.dart';
import 'package:tflite_flutter_helper/tflite_flutter_helper.dart'; import 'package:tflite_flutter_helper/tflite_flutter_helper.dart';
@@ -31,6 +33,7 @@ class Classifier {
/// Labels file loaded as list /// Labels file loaded as list
late List<String> _labels; late List<String> _labels;
int classifierCreationStart = -1; int classifierCreationStart = -1;
bool _shouldReturnFrame = false;
Classifier( Classifier(
Interpreter interpreter, { Interpreter interpreter, {
@@ -67,6 +70,10 @@ class Classifier {
} }
} }
/// Sets whether the next call to `predict` should include the processed
/// input frame (under the `'image'` key of its result map) instead of null.
void setReturnFrame(bool returnFrame) {
_shouldReturnFrame = returnFrame;
}
/// Pre-process the image /// Pre-process the image
TensorImage? getProcessedImage(TensorImage? inputImage) { TensorImage? getProcessedImage(TensorImage? inputImage) {
int cropSize = min(_inputImage.height, _inputImage.width); int cropSize = min(_inputImage.height, _inputImage.width);
@@ -83,10 +90,12 @@ class Classifier {
} }
/// Runs object detection on the input image /// Runs object detection on the input image
Map<String, dynamic>? predict(image_lib.Image image) { Future<Map<String, dynamic>?> predict(image_lib.Image image) async {
var preProcStart = DateTime.now().millisecondsSinceEpoch; var preProcStart = DateTime.now().millisecondsSinceEpoch;
_inputImage.loadImage(image); _inputImage.loadImage(image_lib.copyRotate(image, 90));
_inputImage = getProcessedImage(_inputImage)!; _inputImage = getProcessedImage(_inputImage)!;
var inferenceStart = DateTime.now().millisecondsSinceEpoch; var inferenceStart = DateTime.now().millisecondsSinceEpoch;
_interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer()); _interpreter.run(_inputImage.buffer, _outputBuffer.getBuffer());
var postProcStart = DateTime.now().millisecondsSinceEpoch; var postProcStart = DateTime.now().millisecondsSinceEpoch;
@@ -106,6 +115,7 @@ class Classifier {
inferenceTime: postProcStart - inferenceStart, inferenceTime: postProcStart - inferenceStart,
postProcessingTime: endTime - postProcStart, postProcessingTime: endTime - postProcStart,
), ),
'image': _shouldReturnFrame? _inputImage.image : null,
}; };
} }
+11 -2
View File
@@ -5,6 +5,7 @@ import 'package:tensordex_mobile/tflite/classifier.dart';
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter/tflite_flutter.dart';
import '../utils/image_utils.dart'; import '../utils/image_utils.dart';
import '../utils/logger.dart';
class IsolateBase { class IsolateBase {
final ReceivePort _receivePort = ReceivePort(); final ReceivePort _receivePort = ReceivePort();
@@ -25,6 +26,7 @@ class MLIsolate extends IsolateBase {
_sendPort = await _receivePort.first; _sendPort = await _receivePort.first;
} }
static void entryPoint(SendPort sendPort) async { static void entryPoint(SendPort sendPort) async {
final port = ReceivePort(); final port = ReceivePort();
sendPort.send(port.sendPort); sendPort.send(port.sendPort);
@@ -33,10 +35,15 @@ class MLIsolate extends IsolateBase {
var cameraImage = mlIsolateData.cameraImage; var cameraImage = mlIsolateData.cameraImage;
var converted = ImageUtils.convertCameraImage(cameraImage); var converted = ImageUtils.convertCameraImage(cameraImage);
if (converted != null) { if (converted != null) {
Classifier classifier = Classifier( var classifier = Classifier(
Interpreter.fromAddress(mlIsolateData.interpreterAddress), Interpreter.fromAddress(mlIsolateData.interpreterAddress),
labels: mlIsolateData.labels); labels: mlIsolateData.labels);
var result = classifier.predict(converted); if (classifier.interpreter.address !=
mlIsolateData.interpreterAddress) {
logger.e('INTERPRETER ADDRESS MISMATCH!');
}
classifier.setReturnFrame(mlIsolateData.shouldSaveFrame);
var result = await classifier.predict(converted);
mlIsolateData.responsePort?.send(result); mlIsolateData.responsePort?.send(result);
} else { } else {
mlIsolateData.responsePort?.send({'response': 'not working yet'}); mlIsolateData.responsePort?.send({'response': 'not working yet'});
@@ -49,6 +56,7 @@ class MLIsolate extends IsolateBase {
class MLIsolateData { class MLIsolateData {
CameraImage cameraImage; CameraImage cameraImage;
int interpreterAddress; int interpreterAddress;
bool shouldSaveFrame;
List<String> labels; List<String> labels;
SendPort? responsePort; SendPort? responsePort;
@@ -56,5 +64,6 @@ class MLIsolateData {
this.cameraImage, this.cameraImage,
this.interpreterAddress, this.interpreterAddress,
this.labels, this.labels,
this.shouldSaveFrame,
); );
} }
+61 -17
View File
@@ -1,19 +1,21 @@
import 'dart:convert'; import 'dart:convert';
import 'dart:io';
import 'dart:isolate'; import 'dart:isolate';
import 'package:camera/camera.dart'; import 'package:camera/camera.dart';
import 'package:flutter/material.dart'; import 'package:flutter/material.dart';
import 'package:flutter/services.dart'; import 'package:flutter/services.dart';
import 'package:image/image.dart';
import 'package:path_provider/path_provider.dart';
import 'package:tensordex_mobile/tflite/ml_isolate.dart'; import 'package:tensordex_mobile/tflite/ml_isolate.dart';
import 'package:tensordex_mobile/tflite/model/configuration.dart'; import 'package:tensordex_mobile/tflite/model/configuration.dart';
import 'package:tensordex_mobile/tflite/model/outputs/stats.dart'; import 'package:tensordex_mobile/tflite/model/outputs/stats.dart';
import 'package:tflite_flutter/tflite_flutter.dart'; import 'package:tflite_flutter/tflite_flutter.dart';
import '../main.dart';
import '../tflite/classifier.dart'; import '../tflite/classifier.dart';
import '../tflite/model/outputs/recognition.dart'; import '../tflite/model/outputs/recognition.dart';
import '../utils/logger.dart'; import '../utils/logger.dart';
class PokeFinder extends StatefulWidget { class PokeFinder extends StatefulWidget {
/// Callback to pass results after inference to [HomeView] /// Callback to pass results after inference to [HomeView]
final Function(List<Recognition> recognitions) resultsCallback; final Function(List<Recognition> recognitions) resultsCallback;
@@ -35,9 +37,9 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
bool predicting = false; bool predicting = false;
bool _cameraInitialized = false; bool _cameraInitialized = false;
bool _classifierInitialized = false; bool _classifierInitialized = false;
bool _saveClassifierImage = false;
int cameraIndex = 0;
//cameras
late List<CameraDescription> cameras;
late CameraController cameraController; late CameraController cameraController;
//ml variables //ml variables
@@ -56,7 +58,10 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
WidgetsBinding.instance.addObserver(this); WidgetsBinding.instance.addObserver(this);
_mlIsolate = MLIsolate(); _mlIsolate = MLIsolate();
await _mlIsolate.start(); await _mlIsolate.start();
initializeCamera(); swapToCamera(cameras[0]);
for (CameraDescription cam in cameras) {
logger.i(cam);
}
initializeModel(); initializeModel();
predicting = false; predicting = false;
} }
@@ -86,16 +91,12 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
options: config.interpreters[0]); options: config.interpreters[0]);
} }
/// Initializes the camera by setting [cameraController] void swapToCamera(CameraDescription cameraDescription) async {
void initializeCamera() async { cameraController = CameraController(cameraDescription, ResolutionPreset.low,
cameras = await availableCameras(); enableAudio: false);
cameraController =
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
cameraController.initialize().then((_) async { cameraController.initialize().then((_) async {
/// previewSize is size of each image frame captured by controller /// previewSize is size of each image frame captured by controller
/// 352x288 on iOS, 240p (320x240) on Android with ResolutionPreset.low /// 352x288 on iOS, 240p (320x240) on Android with ResolutionPreset.low
// Stream of image passed to [onLatestImageAvailable] callback
await cameraController.startImageStream(onLatestImageAvailable); await cameraController.startImageStream(onLatestImageAvailable);
setState(() { setState(() {
_cameraInitialized = true; _cameraInitialized = true;
@@ -112,8 +113,12 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
setState(() { setState(() {
predicting = true; predicting = true;
}); });
logger.i(_saveClassifierImage);
var results = await inference(MLIsolateData( var results = await inference(MLIsolateData(
cameraImage, classifier.interpreter.address, classifier.labels)); cameraImage,
classifier.interpreter.address,
classifier.labels,
_saveClassifierImage));
if (results.containsKey('recognitions')) { if (results.containsKey('recognitions')) {
widget.resultsCallback(results['recognitions']); widget.resultsCallback(results['recognitions']);
@@ -121,23 +126,62 @@ class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
if (results.containsKey('stats')) { if (results.containsKey('stats')) {
widget.statsCallback(results['stats']); widget.statsCallback(results['stats']);
} }
logger.i(results); if (results.containsKey('image')) {
var image = results['image'];
if (image != null) {
Directory tempDir = await getTemporaryDirectory();
String tempPath = tempDir.path;
logger.i(tempPath);
logger.i('SAVING IMAGE!');
await File('$tempPath/${DateTime.now().millisecondsSinceEpoch}.png')
.writeAsBytes(encodePng(image));
_saveClassifierImage = false;
}
}
setState(() { setState(() {
predicting = false; predicting = false;
}); });
} }
} }
/// Advances to the next available camera (wrapping back to the first after
/// the last) and re-initializes the preview via [swapToCamera].
///
/// Fixes: the original was declared `async` without a single `await`
/// (unnecessary-async lint); the manual wrap-around check is replaced with
/// a modulo, and an empty camera list is handled explicitly instead of
/// crashing with a RangeError.
void swapCamera() {
  logger.i(cameras);
  logger.i(cameraIndex);
  if (cameras.isEmpty) {
    // Nothing to swap to — avoid a division-by-zero / RangeError below.
    logger.e('swapCamera called but no cameras are available');
    return;
  }
  // Modulo wraps the index in one step.
  cameraIndex = (cameraIndex + 1) % cameras.length;
  swapToCamera(cameras[cameraIndex]);
}
/// Flags the next inference pass to return its processed model-input frame
/// so it can be written to disk; the flag is consumed (reset) after saving.
///
/// Fixes: the original was declared `async` without any `await`
/// (unnecessary-async lint) — this is a plain synchronous setter.
void saveMLImage() {
  logger.i('setting save classifier to true');
  _saveClassifierImage = true;
}
/// Logs the camera's supported zoom range, then applies a fixed 2.0x zoom.
///
/// Fixes: the original passed the un-awaited `Future` returned by
/// `setZoomLevel` straight into `logger.i`, so the zoom call's errors were
/// silently dropped and the log line printed a Future instance. The call is
/// now awaited, and the function returns `Future<void>` so callers can await
/// or handle failures (still assignable to `VoidCallback` for `onPressed`).
Future<void> setZoom() async {
  logger.i(await cameraController.getMinZoomLevel());
  logger.i(await cameraController.getMaxZoomLevel());
  await cameraController.setZoomLevel(2.0);
}
@override @override
Widget build(BuildContext context) { Widget build(BuildContext context) {
// Return empty container while the camera is not initialized // Return empty container while the camera is not initialized
if (!_cameraInitialized) { if (!_cameraInitialized) {
return Container(); return Container();
} }
return AspectRatio( return Column(
children: [
AspectRatio(
aspectRatio: 1 / cameraController.value.aspectRatio, aspectRatio: 1 / cameraController.value.aspectRatio,
child: CameraPreview(cameraController)); child: CameraPreview(cameraController)),
TextButton(onPressed: swapCamera, child: const Text('Change Camera!')),
TextButton(
onPressed: saveMLImage, child: const Text('Save Model Image')),
TextButton(onPressed: setZoom, child: const Text('Zoom!'))
],
);
} }
/// Runs inference in another isolate /// Runs inference in another isolate
+28
View File
@@ -64,6 +64,34 @@ class _TensordexHomeState extends State<TensordexHome> {
appBar: AppBar( appBar: AppBar(
title: Text(widget.title), title: Text(widget.title),
), ),
drawer: Drawer(
child: ListView(
// Important: Remove any padding from the ListView.
padding: EdgeInsets.zero,
children: [
const DrawerHeader(
decoration: BoxDecoration(
color: Colors.blue,
),
child: Text('Drawer Header'),
),
ListTile(
title: const Text('Item 1'),
onTap: () {
// Update the state of the app.
// ...
},
),
ListTile(
title: const Text('Item 2'),
onTap: () {
// Update the state of the app.
// ...
},
),
],
),
),
body: Center( body: Center(
child: Column( child: Column(
mainAxisAlignment: MainAxisAlignment.start, mainAxisAlignment: MainAxisAlignment.start,