All files / js/widgets Ai.tsx

85.71% Statements 42/49
66.67% Branches 8/12
73.33% Functions 11/15
88.64% Lines 39/44

Press n or j to go to the next uncovered block, b, p or k for the previous block.

1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142                      15x                                       7x   84x   7x     7x 7x       7x   7x   7x 4x     232x   3x   3x 2x       7x 3x   3x 3x                   232x       232x 232x 107x 107x   107x 107x   125x 125x   125x 125x   232x   232x 232x       235x 235x     235x 235x     235x             7x                             2x                        
import React, {useEffect, useRef, useState} from "react";
import {defaultVisualizationOptions, draw} from "../visualization/canvasDraw";
import {Container, Grid, Stack} from "@mui/material";
// @ts-ignore
import jsYaml from "js-yaml";
// @ts-ignore
import rawAiOptions from "../../resources/models/strategies.yaml"
import JsGameState from "../SnakeLogic/JsGameState";
import AddAutopilot from "./AddAutopilot";
// import FieldSizeSelector from "./FieldSizeSelector";
// Import TensorFlow.js dynamically (webpack puts it into a separate chunk), since it is large and only needed for one route.
const {tensor, loadLayersModel, engine} = await import("@tensorflow/tfjs");
// this one is just for the type and is hopefully dead-code-eliminated from the bundle
import {LayersModel} from "@tensorflow/tfjs";
 
/**
 * A single AI strategy entry as described in `strategies.yaml`.
 */
export type AiOption = {
    /** Identifier of the strategy. */
    id: string;
    /** Human-readable name (presumably shown in the strategy picker). */
    label: string;
    /** Longer free-text description of the strategy. */
    description: string;
    /** Location of the TensorFlow.js model; passed to `loadLayersModel`. */
    model_path_js: string;
    /** `"global"` feeds the model `trainingBitmap()`; any other value feeds `trainingState()`. */
    input: string;
    /** Strategy mode as given in the YAML file; not read by the Ai widget. */
    mode: string;
};
 
/** The selected strategy together with its loaded network, if any. */
type Model = {
    /** Strategy whose network is loaded (or being loaded). */
    options: AiOption;
    /** The loaded network; `undefined` until loading for `options` has finished. */
    model: LayersModel | undefined;
};
 
/**
 * Parses the bundled strategies file and keeps only the strategies that come
 * with a model in TensorFlow.js format (i.e. entries that have `model_path_js`).
 *
 * @throws Error if no strategy has a `model_path_js` entry, because the widget
 *         would have nothing to load.
 */
function loadAiOptions(): AiOption[] {
    const allOptions = jsYaml.load(rawAiOptions);
    const withJsModel: AiOption[] = allOptions.filter((obj: AiOption) => "model_path_js" in obj);
    if (withJsModel.length === 0) {
        throw new Error("strategies.yaml does not contain any strategy with a 'model_path_js' entry");
    }
    return withJsModel;
}

/** Calls `dispose()` on every layer of `net` so its TensorFlow.js resources can be released. */
function disposeModel(net: LayersModel) {
    net.layers.forEach(l => l.dispose());
}

/**
 * Runs `net` on a batch containing a single input and returns the index of the
 * highest-scoring entry along axis 1 (the chosen action) for that sample.
 *
 * All intermediate tensors live in an engine scope that is closed even if
 * prediction throws, so a failing model cannot leak tensors on every tick.
 */
function predictAction(net: LayersModel, batch: Parameters<typeof tensor>[0]): number {
    engine().startScope();
    try {
        const out = net.predict(tensor(batch));
        // Multi-output models return an array of tensors and the first one is used;
        // single-output models return the tensor directly.
        const first = Array.isArray(out) ? out[0] : out;
        return (first.argMax(1).arraySync() as number[])[0];
    } finally {
        engine().endScope();
    }
}

export default function Ai() {
    // Parse the strategies once per component instance. A stable array (with
    // stable option objects) lets the current selection be compared by reference.
    const [aiJsOptions] = useState<AiOption[]>(loadAiOptions);

    const visOpts = defaultVisualizationOptions;

    // game is a mutable object, it will mutate itself and will not trigger re-rendering.
    // The lazy initializer ensures only a single JsGameState is ever constructed.
    const [game] = useState(() => new JsGameState(10, 10, (score) => gameOver(score)));
    const [model, setModel] = useState<Model>(() => ({
        options: aiJsOptions[aiJsOptions.length - 1],
        model: undefined,
    }));
    const canvasRef = useRef<HTMLCanvasElement|null>(null);

    // While a model is loaded, draw the current state and then step every 30 ms.
    // This effect has no dependency list, so it is re-created after every render
    // and step() always sees the latest `model`.
    useEffect(() => {
        // if we do not have a model, we don't need to try to step
        if (model.model === undefined) {
            return;
        }

        // Draw before starting the timer: if the canvas is unavailable,
        // getContext() throws and no interval is left running without a cleanup.
        draw(getContext(), game, visOpts);
        const refresh = window.setInterval(() => step(), 30);

        return () => {
            window.clearInterval(refresh);
        };
    });

    // Load the network of the selected strategy whenever the selection changes.
    // The cleanup runs on the next selection change or on unmount. It disposes the
    // network this effect loaded and marks a still-pending load as stale. A slow,
    // outdated download therefore neither overwrites the newer selection nor leaks.
    useEffect(() => {
        const options = model.options;
        let cancelled = false;
        let loaded: LayersModel | undefined;

        loadLayersModel(options.model_path_js)
            .then(m => {
                if (cancelled) {
                    disposeModel(m);
                    return;
                }
                loaded = m;
                // Same `options` reference, so this does not re-trigger this effect.
                setModel({options, model: m});
            })
            .catch(err => {
                console.error(`Failed to load model from ${options.model_path_js}`, err);
            });

        return () => {
            cancelled = true;
            if (loaded !== undefined) {
                disposeModel(loaded);
            }
        };
    }, [model.options]);

    // function newGame(width: number, height: number) {
    //     const game = new JsGameState(width, height, (score) => gameOver(score));
    //     setGame(game);
    // }

    /**
     * Advances the game by one tick: lets the loaded model choose a move, applies
     * it, updates the game and redraws. Does nothing while no model is loaded.
     */
    function step() {
        const net = model.model;
        if (net === undefined) {
            return;
        }

        // "global" strategies get trainingBitmap() and their action is applied via
        // absoluteAction2Move; all other strategies get trainingState() and
        // relativeAction2Move.
        if (model.options.input === "global") {
            game.absoluteAction2Move(predictAction(net, [game.trainingBitmap()]));
        } else {
            game.relativeAction2Move(predictAction(net, [game.trainingState()]));
        }

        game.update();
        draw(getContext(), game, visOpts);
    }

    /**
     * Returns the 2D drawing context of the game canvas.
     *
     * @throws Error if the canvas is not mounted or has no 2D context.
     */
    function getContext(): CanvasRenderingContext2D {
        const canvas = canvasRef.current;
        if (canvas === null) {
            throw new Error("Canvas failed to construct");
        }
        const context = canvas.getContext('2d');
        if (context === null) {
            throw new Error("Failed to get the Context of the constructed canvas");
        }
        return context;
    }

    function gameOver(score: number) {
        console.log(score)
    }

    // Choosing a different strategy clears the current model, which makes the
    // loading effect fetch the new one. Re-selecting the current strategy is a
    // no-op: an unchanged `model.options` would never trigger a reload, so
    // clearing the model would stop the AI for good.
    function selectOption(options: AiOption) {
        setModel(prev => prev.options === options ? prev : {options, model: undefined});
    }

    return (
        <Container maxWidth="lg">
            <Grid container spacing={4} pt={4} justifyContent="space-around" alignItems="flex-start">
                <Grid item xs={12} lg={6}>
                    <canvas
                        ref={canvasRef}
                        width={game.width * visOpts.scale}
                        height={game.height * visOpts.scale}
                        id={"snakeCanvas"}
                    />
                </Grid>
                <Grid item xs={12} lg={6}>
                    <Stack spacing={2}>
                        <AddAutopilot
                            onCommit={obj => selectOption(obj)}
                            onChange={obj => selectOption(obj)}
                            aiOptions={aiJsOptions}
                            submitText={"Change AI"}
                            width={"100%"}
                            defaultValue={model.options}
                            commitMode={false}
                        />
                    </Stack>
                </Grid>
            </Grid>
        </Container>
    )
}