001/*
002 * Copyright (C) Photon Vision.
003 *
004 * This program is free software: you can redistribute it and/or modify
005 * it under the terms of the GNU General Public License as published by
006 * the Free Software Foundation, either version 3 of the License, or
007 * (at your option) any later version.
008 *
009 * This program is distributed in the hope that it will be useful,
010 * but WITHOUT ANY WARRANTY; without even the implied warranty of
011 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
012 * GNU General Public License for more details.
013 *
014 * You should have received a copy of the GNU General Public License
015 * along with this program.  If not, see <https://www.gnu.org/licenses/>.
016 */
017
018package org.photonvision.common.configuration;
019
020import java.io.File;
021import java.io.IOException;
022import java.io.InputStream;
023import java.net.URISyntaxException;
024import java.nio.file.Files;
025import java.nio.file.Path;
026import java.nio.file.StandardCopyOption;
027import java.util.ArrayList;
028import java.util.Arrays;
029import java.util.Enumeration;
030import java.util.HashMap;
031import java.util.List;
032import java.util.Map;
033import java.util.Optional;
034import java.util.jar.JarEntry;
035import java.util.jar.JarFile;
036import java.util.regex.Matcher;
037import java.util.regex.Pattern;
038import org.photonvision.common.hardware.Platform;
039import org.photonvision.common.logging.LogGroup;
040import org.photonvision.common.logging.Logger;
041import org.photonvision.vision.objects.Model;
042import org.photonvision.vision.objects.RknnModel;
043
044/**
045 * Manages the loading of neural network models.
046 *
047 * <p>Models are loaded from the filesystem at the <code>modelsFolder</code> location. PhotonVision
048 * also supports shipping pre-trained models as resources in the JAR. If the model has already been
049 * extracted to the filesystem, it will not be extracted again.
050 *
051 * <p>Each model must have a corresponding <code>labels</code> file. The labels file format is
052 * simply a list of string names per label, one label per line. The labels file must have the same
053 * name as the model file, but with the suffix <code>-labels.txt</code> instead of <code>.rknn
054 * </code>.
055 */
056public class NeuralNetworkModelManager {
057    /** Singleton instance of the NeuralNetworkModelManager */
058    private static NeuralNetworkModelManager INSTANCE;
059
060    /**
061     * Private constructor to prevent instantiation
062     *
063     * @return The NeuralNetworkModelManager instance
064     */
065    private NeuralNetworkModelManager() {
066        ArrayList<NeuralNetworkBackend> backends = new ArrayList<>();
067
068        if (Platform.isRK3588()) {
069            backends.add(NeuralNetworkBackend.RKNN);
070        }
071
072        supportedBackends = backends;
073    }
074
075    /**
076     * Returns the singleton instance of the NeuralNetworkModelManager
077     *
078     * @return The singleton instance
079     */
080    public static NeuralNetworkModelManager getInstance() {
081        if (INSTANCE == null) {
082            INSTANCE = new NeuralNetworkModelManager();
083        }
084        return INSTANCE;
085    }
086
087    /** Logger for the NeuralNetworkModelManager */
088    private static final Logger logger = new Logger(NeuralNetworkModelManager.class, LogGroup.Config);
089
090    public enum NeuralNetworkBackend {
091        RKNN(".rknn");
092
093        private String format;
094
095        private NeuralNetworkBackend(String format) {
096            this.format = format;
097        }
098    }
099
100    private final List<NeuralNetworkBackend> supportedBackends;
101
102    /**
103     * Retrieves the list of supported backends.
104     *
105     * @return the list
106     */
107    public List<String> getSupportedBackends() {
108        return supportedBackends.stream().map(Enum::toString).toList();
109    }
110
111    /**
112     * Stores model information, such as the model file, labels, and version.
113     *
114     * <p>The first model in the list is the default model.
115     */
116    private Map<NeuralNetworkBackend, ArrayList<Model>> models;
117
118    /**
119     * Retrieves the deep neural network models available, in a format that can be used by the
120     * frontend.
121     *
122     * @return A map containing the available models, where the key is the backend and the value is a
123     *     list of model names.
124     */
125    public HashMap<String, ArrayList<String>> getModels() {
126        HashMap<String, ArrayList<String>> modelMap = new HashMap<>();
127        if (models == null) {
128            return modelMap;
129        }
130
131        models.forEach(
132                (backend, backendModels) -> {
133                    ArrayList<String> modelNames = new ArrayList<>();
134                    backendModels.forEach(model -> modelNames.add(model.getName()));
135                    modelMap.put(backend.toString(), modelNames);
136                });
137
138        return modelMap;
139    }
140
141    /**
142     * Retrieves the model with the specified name, assuming it is available under a supported
143     * backend.
144     *
145     * <p>If this method returns `Optional.of(..)` then the model should be safe to load.
146     *
147     * @param modelName the name of the model to retrieve
148     * @return an Optional containing the model if found, or an empty Optional if not found
149     */
150    public Optional<Model> getModel(String modelName) {
151        if (models == null) {
152            return Optional.empty();
153        }
154
155        // Check if the model exists in any supported backend
156        for (NeuralNetworkBackend backend : supportedBackends) {
157            if (models.containsKey(backend)) {
158                Optional<Model> model =
159                        models.get(backend).stream().filter(m -> m.getName().equals(modelName)).findFirst();
160                if (model.isPresent()) {
161                    return model;
162                }
163            }
164        }
165
166        return Optional.empty();
167    }
168
169    /** The default model when no model is specified. */
170    public Optional<Model> getDefaultModel() {
171        if (models == null) {
172            return Optional.empty();
173        }
174
175        if (supportedBackends.isEmpty()) {
176            return Optional.empty();
177        }
178
179        return models.get(supportedBackends.get(0)).stream().findFirst();
180    }
181
182    private void loadModel(File model) {
183        if (models == null) {
184            models = new HashMap<>();
185        }
186
187        // Get the model extension and check if it is supported
188        String modelExtension = model.getName().substring(model.getName().lastIndexOf('.'));
189        if (modelExtension.equals(".txt")) {
190            return;
191        }
192
193        Optional<NeuralNetworkBackend> backend =
194                Arrays.stream(NeuralNetworkBackend.values())
195                        .filter(b -> b.format.equals(modelExtension))
196                        .findFirst();
197
198        if (!backend.isPresent()) {
199            logger.warn("Model " + model.getName() + " has an unknown extension.");
200            return;
201        }
202
203        String labels = model.getAbsolutePath().replace(backend.get().format, "-labels.txt");
204        if (!models.containsKey(backend.get())) {
205            models.put(backend.get(), new ArrayList<>());
206        }
207
208        try {
209            switch (backend.get()) {
210                case RKNN -> {
211                    models.get(backend.get()).add(new RknnModel(model, labels));
212                    logger.info(
213                            "Loaded model " + model.getName() + " for backend " + backend.get().toString());
214                }
215            }
216        } catch (IllegalArgumentException e) {
217            logger.error("Failed to load model " + model.getName(), e);
218        } catch (IOException e) {
219            logger.error("Failed to read labels for model " + model.getName(), e);
220        }
221    }
222
223    /**
224     * Discovers DNN models from the specified folder.
225     *
226     * @param modelsDirectory The folder where the models are stored
227     */
228    public void discoverModels(File modelsDirectory) {
229        logger.info("Supported backends: " + supportedBackends);
230
231        if (!modelsDirectory.exists()) {
232            logger.error("Models folder " + modelsDirectory.getAbsolutePath() + " does not exist.");
233            return;
234        }
235
236        models = new HashMap<>();
237
238        try {
239            Files.walk(modelsDirectory.toPath())
240                    .filter(Files::isRegularFile)
241                    .forEach(path -> loadModel(path.toFile()));
242        } catch (IOException e) {
243            logger.error("Failed to discover models at " + modelsDirectory.getAbsolutePath(), e);
244        }
245
246        // After loading all of the models, sort them by name to ensure a consistent
247        // ordering
248        models.forEach(
249                (backend, backendModels) ->
250                        backendModels.sort((a, b) -> a.getName().compareTo(b.getName())));
251
252        // Log
253        StringBuilder sb = new StringBuilder();
254        sb.append("Discovered models: ");
255        models.forEach(
256                (backend, backendModels) -> {
257                    sb.append(backend).append(" [");
258                    backendModels.forEach(model -> sb.append(model.getName()).append(", "));
259                    sb.append("] ");
260                });
261    }
262
263    /**
264     * Extracts models from the JAR and copies them to disk.
265     *
266     * @param modelsDirectory the directory on disk to save models
267     */
268    public void extractModels(File modelsDirectory) {
269        if (!modelsDirectory.exists() && !modelsDirectory.mkdirs()) {
270            throw new RuntimeException("Failed to create directory: " + modelsDirectory);
271        }
272
273        String resource = "models";
274
275        try {
276            String jarPath =
277                    getClass().getProtectionDomain().getCodeSource().getLocation().toURI().getPath();
278            try (JarFile jarFile = new JarFile(jarPath)) {
279                Enumeration<JarEntry> entries = jarFile.entries();
280                while (entries.hasMoreElements()) {
281                    JarEntry entry = entries.nextElement();
282                    if (!entry.getName().startsWith(resource + "/") || entry.isDirectory()) {
283                        continue;
284                    }
285                    Path outputPath =
286                            modelsDirectory.toPath().resolve(entry.getName().substring(resource.length() + 1));
287
288                    if (Files.exists(outputPath)) {
289                        logger.info("Skipping extraction of DNN resource: " + entry.getName());
290                        continue;
291                    }
292
293                    Files.createDirectories(outputPath.getParent());
294                    try (InputStream inputStream = jarFile.getInputStream(entry)) {
295                        Files.copy(inputStream, outputPath, StandardCopyOption.REPLACE_EXISTING);
296                        logger.info("Extracted DNN resource: " + entry.getName());
297                    } catch (IOException e) {
298                        logger.error("Failed to extract DNN resource: " + entry.getName(), e);
299                    }
300                }
301            }
302        } catch (IOException | URISyntaxException e) {
303            logger.error("Error extracting models", e);
304        }
305    }
306
307    private static Pattern modelPattern =
308            Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)\\.rknn$");
309
310    private static Pattern labelsPattern =
311            Pattern.compile("^([a-zA-Z0-9._]+)-(\\d+)-(\\d+)-(yolov(?:5|8|11)[nsmlx]*)-labels\\.txt$");
312
313    /**
314     * Check naming conventions for models and labels.
315     *
316     * <p>This is static as it is not dependent on the state of the class.
317     *
318     * @param modelName the name of the model
319     * @param labelsName the name of the labels file
320     * @throws IllegalArgumentException if the names are invalid
321     */
322    public static void verifyRKNNNames(String modelName, String labelsName) {
323        // check null
324        if (modelName == null || labelsName == null) {
325            throw new IllegalArgumentException("Model name and labels name cannot be null");
326        }
327
328        // These patterns check that the naming convention of
329        // name-widthResolution-heightResolution-modelType is followed
330
331        Matcher modelMatcher = modelPattern.matcher(modelName);
332        Matcher labelsMatcher = labelsPattern.matcher(labelsName);
333
334        if (!modelMatcher.matches() || !labelsMatcher.matches()) {
335            throw new IllegalArgumentException(
336                    "Model name and labels name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn and name-widthResolution-heightResolution-modelType-labels.txt");
337        }
338
339        if (!modelMatcher.group(1).equals(labelsMatcher.group(1))
340                || !modelMatcher.group(2).equals(labelsMatcher.group(2))
341                || !modelMatcher.group(3).equals(labelsMatcher.group(3))
342                || !modelMatcher.group(4).equals(labelsMatcher.group(4))) {
343            throw new IllegalArgumentException("Model name and labels name must be matching.");
344        }
345    }
346
347    /**
348     * Parse RKNN name and return the name, width, height, and model type.
349     *
350     * <p>This is static as it is not dependent on the state of the class.
351     *
352     * @param modelName the name of the model
353     * @throws IllegalArgumentException if the model name does not follow the naming convention
354     * @return an array containing the name, width, height, and model type
355     */
356    public static String[] parseRKNNName(String modelName) {
357        Matcher modelMatcher = modelPattern.matcher(modelName);
358
359        if (!modelMatcher.matches()) {
360            throw new IllegalArgumentException(
361                    "Model name must follow the naming convention of name-widthResolution-heightResolution-modelType.rknn");
362        }
363
364        return new String[] {
365            modelMatcher.group(1), modelMatcher.group(2), modelMatcher.group(3), modelMatcher.group(4)
366        };
367    }
368}