001/*
002 * Copyright (C) Photon Vision.
003 *
004 * This program is free software: you can redistribute it and/or modify
005 * it under the terms of the GNU General Public License as published by
006 * the Free Software Foundation, either version 3 of the License, or
007 * (at your option) any later version.
008 *
009 * This program is distributed in the hope that it will be useful,
010 * but WITHOUT ANY WARRANTY; without even the implied warranty of
011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
012 * GNU General Public License for more details.
013 *
014 * You should have received a copy of the GNU General Public License
015 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
016 */
017
018package org.photonvision.vision.objects;
019
import java.io.File;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.Collections;
import java.util.List;
import org.opencv.core.Size;
import org.photonvision.common.configuration.NeuralNetworkModelManager;
import org.photonvision.jni.RknnObjectDetector;
import org.photonvision.rknn.RknnJNI;
029
030public class RknnModel implements Model {
031    public final File modelFile;
032    public final RknnJNI.ModelVersion version;
033    public final List<String> labels;
034    public final Size inputSize;
035
036    /**
037     * Determines the model version based on the model's filename.
038     *
039     * <p>"yolov5" -> "YOLO_V5"
040     *
041     * <p>"yolov8" -> "YOLO_V8"
042     *
043     * <p>"yolov11" -> "YOLO_V11"
044     *
045     * @param modelName The model's filename
046     * @return The model version
047     */
048    private static RknnJNI.ModelVersion getModelVersion(String modelName)
049            throws IllegalArgumentException {
050        if (modelName.contains("yolov5")) {
051            return RknnJNI.ModelVersion.YOLO_V5;
052        } else if (modelName.contains("yolov8")) {
053            return RknnJNI.ModelVersion.YOLO_V8;
054        } else if (modelName.contains("yolov11")) {
055            return RknnJNI.ModelVersion.YOLO_V11;
056        } else {
057            throw new IllegalArgumentException("Unknown model version for model " + modelName);
058        }
059    }
060
061    /**
062     * rknn model constructor.
063     *
064     * @param modelFile path to model on disk. Format: `name-width-height-model.rknn`
065     * @param labels path to labels file on disk
066     * @throws IllegalArgumentException
067     */
068    public RknnModel(File modelFile, String labels) throws IllegalArgumentException, IOException {
069        this.modelFile = modelFile;
070
071        // parseRKNNName throws an IllegalArgumentException if the model name is invalid
072        String[] parts = NeuralNetworkModelManager.parseRKNNName(modelFile.getName());
073
074        this.version = getModelVersion(parts[3]);
075
076        int width = Integer.parseInt(parts[1]);
077        int height = Integer.parseInt(parts[2]);
078        this.inputSize = new Size(width, height);
079
080        try {
081            this.labels = Files.readAllLines(Paths.get(labels));
082        } catch (IOException e) {
083            throw new IllegalArgumentException("Failed to read labels file " + labels, e);
084        }
085    }
086
087    public String getName() {
088        return modelFile.getName();
089    }
090
091    public ObjectDetector load() {
092        return new RknnObjectDetector(this, inputSize);
093    }
094}