refactoring to a widget dir

This commit is contained in:
Lucas Oskorep
2022-06-22 22:02:45 -04:00
parent 9ec737db46
commit 8044485d8c
5 changed files with 23 additions and 31 deletions
+152
View File
@@ -0,0 +1,152 @@
import 'dart:isolate';
import 'package:camera/camera.dart';
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/tflite/ml_isolate.dart';
import 'package:tflite_flutter/tflite_flutter.dart';
import '../tflite/classifier.dart';
import '../utils/logger.dart';
import '../tflite/data/recognition.dart';
import '../tflite/data/stats.dart';
/// [PokeFinder] sends each frame for inference
class PokeFinder extends StatefulWidget {
/// Callback to pass results after inference to [HomeView]
final Function(List<Recognition> recognitions) resultsCallback;
/// Callback to inference stats to [HomeView]
final Function(Stats stats) statsCallback;
/// Constructor
const PokeFinder(
{Key? key, required this.resultsCallback, required this.statsCallback})
: super(key: key);
@override
State<PokeFinder> createState() => _PokeFinderState();
}
class _PokeFinderState extends State<PokeFinder> with WidgetsBindingObserver {
late List<CameraDescription> cameras;
late CameraController cameraController;
late MLIsolate _mlIsolate;
/// true when inference is ongoing
bool predicting = false;
bool _cameraInitialized = false;
bool _classifierInitialized = false;
late Interpreter interpreter;
late Classifier classifier;
@override
void initState() {
initStateAsync();
super.initState();
}
void initStateAsync() async {
WidgetsBinding.instance.addObserver(this);
_mlIsolate = MLIsolate();
await _mlIsolate.start();
initializeCamera();
initializeModel();
predicting = false;
}
void initializeModel() async {
var interpreterOptions = InterpreterOptions()..threads = 8;
interpreter = await Interpreter.fromAsset('efficientnet_v2s.tflite',
options: interpreterOptions);
classifier = Classifier(interpreter: interpreter);
_classifierInitialized = true;
}
/// Initializes the camera by setting [cameraController]
void initializeCamera() async {
cameras = await availableCameras();
// cameras[0] for rear-camera
cameraController =
CameraController(cameras[0], ResolutionPreset.low, enableAudio: false);
cameraController.initialize().then((_) async {
/// previewSize is size of each image frame captured by controller
/// 352x288 on iOS, 240p (320x240) on Android with ResolutionPreset.low
// Stream of image passed to [onLatestImageAvailable] callback
await cameraController.startImageStream(onLatestImageAvailable);
setState(() {
_cameraInitialized = true;
});
});
}
/// Callback to receive each frame [CameraImage] perform inference on it
onLatestImageAvailable(CameraImage cameraImage) async {
if (_classifierInitialized) {
if (predicting) {
return;
}
setState(() {
predicting = true;
});
var results = await inference(MLIsolateData(
cameraImage, classifier.interpreter.address, classifier.labels));
if (results.containsKey("recognitions")) {
widget.resultsCallback(results["recognitions"]);
}
if (results.containsKey("stats")) {
widget.statsCallback(results["stats"]);
}
logger.i(results);
setState(() {
predicting = false;
});
}
}
@override
Widget build(BuildContext context) {
// Return empty container while the camera is not initialized
if (!_cameraInitialized) {
return Container();
}
return AspectRatio(
aspectRatio: 1 / cameraController.value.aspectRatio,
child: CameraPreview(cameraController));
}
/// Runs inference in another isolate
Future<Map<String, dynamic>> inference(MLIsolateData mlIsolateData) async {
ReceivePort responsePort = ReceivePort();
_mlIsolate.sendPort
.send(mlIsolateData..responsePort = responsePort.sendPort);
var results = await responsePort.first;
return results;
}
@override
void didChangeAppLifecycleState(AppLifecycleState state) async {
switch (state) {
case AppLifecycleState.paused:
cameraController.stopImageStream();
break;
case AppLifecycleState.resumed:
if (!cameraController.value.isStreamingImages) {
await cameraController.startImageStream(onLatestImageAvailable);
}
break;
default:
}
}
@override
void dispose() {
WidgetsBinding.instance.removeObserver(this);
cameraController.dispose();
super.dispose();
}
}
+28
View File
@@ -0,0 +1,28 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart';
import 'package:tensordex_mobile/tflite/data/recognition.dart';
import 'package:tensordex_mobile/tflite/data/stats.dart';
/// [PokeFinder] sends each frame for inference
class Results extends StatefulWidget {
final List<Recognition> recognitions;
final Stats stats;
/// Constructor
const Results(this.recognitions, this.stats, {Key? key}) : super(key: key);
@override
State<Results> createState() => _ResultsState();
}
class _ResultsState extends State<Results> {
@override
void initState() {
super.initState();
}
@override
Widget build(BuildContext context) {
return Text(widget.recognitions.toString());
}
}
+89
View File
@@ -0,0 +1,89 @@
import 'package:flutter/material.dart';
import 'package:tensordex_mobile/widgets/poke_finder.dart';
import 'package:tensordex_mobile/widgets/results.dart';
import '../utils/logger.dart';
import '../tflite/data/recognition.dart';
import '../tflite/data/stats.dart';
class TensordexHome extends StatefulWidget {
const TensordexHome({Key? key, required this.title}) : super(key: key);
// This class is the configuration for the state. It holds the values (in this
// case the title) provided by the parent (in this case the App widget) and
// used by the build method of the State. Fields in a Widget subclass are
// always marked "final".
final String title;
@override
State<TensordexHome> createState() => _TensordexHomeState();
}
class _TensordexHomeState extends State<TensordexHome> {
/// Results from the image classifier
List<Recognition> results = [Recognition(1, "NOTHING DETECTED", .5)];
Stats stats = Stats();
/// Scaffold Key
GlobalKey<ScaffoldState> scaffoldKey = GlobalKey();
void _incrementCounter() {
setState(() {
logger.d("Counter Incremented!");
});
}
@override
void initState() {
super.initState();
}
@override
void dispose() {
super.dispose();
}
/// Callback to get inference results from [PokeFinder]
void resultsCallback(List<Recognition> results) {
setState(() {
this.results = results;
});
}
/// Callback to get inference stats from [PokeFinder]
void statsCallback(Stats stats) {
setState(() {
this.stats = stats;
});
}
@override
Widget build(BuildContext context) {
return Scaffold(
appBar: AppBar(
title: Text(widget.title),
),
body: Center(
child: Column(
mainAxisAlignment: MainAxisAlignment.start,
children: <Widget>[
PokeFinder(
resultsCallback: resultsCallback,
statsCallback: statsCallback),
Results(results, stats),
],
),
),
floatingActionButton: GestureDetector(
onLongPress: () {
_incrementCounter();
},
child: FloatingActionButton(
onPressed: _incrementCounter,
tooltip: 'Increment',
child: const Icon(Icons.photo_camera),
), // This trailing comma makes auto-formatting nicer for build methods.
));
}
}