LocalDeepAutopilot.java
package me.schawe.multijsnake.snake.ai;
import me.schawe.multijsnake.snake.GameState;
import me.schawe.multijsnake.snake.Move;
import me.schawe.multijsnake.snake.Snake;
import me.schawe.multijsnake.snake.TrainingState;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
/**
 * Autopilot that picks a snake's next move by evaluating a Keras model loaded through
 * {@link KerasModelAutopilot}.
 *
 * <p>This class builds the {@link TrainingState} vector for a snake and feeds it to whichever
 * model the superclass populated, sequential or functional. It then maps the highest-scoring
 * output to a {@link Move}.
 */
public class LocalDeepAutopilot extends KerasModelAutopilot {

    /**
     * Creates an autopilot backed by the Keras model at {@code pathToModel}.
     *
     * @param pathToModel  path of the model file, forwarded unchanged to the superclass
     * @param isFunctional forwarded unchanged to the superclass; presumably selects whether the
     *                     model is loaded as a functional or a sequential model (TODO confirm in
     *                     {@link KerasModelAutopilot})
     */
    public LocalDeepAutopilot(String pathToModel, boolean isFunctional) {
        super(pathToModel, isFunctional);
    }

    /**
     * Suggests the next move for {@code snake} by running the loaded model on the snake's
     * training-state vector and choosing the output index with the largest value.
     *
     * @param gameState current game state used to build the model input
     * @param snake     the snake to steer
     * @return the move obtained from
     *         {@link TrainingState#relativeAction2Move} for the argmax action index and the
     *         snake's last head direction
     * @throws IllegalStateException if neither a sequential nor a functional model is
     *         available. In that case the snake is killed in {@code gameState} before the
     *         exception is thrown.
     */
    @Override
    public Move suggest(GameState gameState, Snake snake) {
        // Build a single-row batch of shape (1, features) from this snake's view of the game.
        var state = new TrainingState(gameState).vector(snake.getId());
        INDArray input = Nd4j.create(state).reshape(1, state.size());

        INDArray output;
        if (modelSequential != null) {
            output = modelSequential.output(input);
        } else if (modelFunctional != null) {
            // The functional model's output(...) returns an array; only the first entry is used.
            output = modelFunctional.output(input)[0];
        } else {
            // No usable model: remove the snake from play, then report the failure.
            gameState.kill(snake.getId());
            throw new IllegalStateException("failed to load model `" + pathToModel + "`");
        }

        // The flattened output's largest entry gives the chosen action index.
        int action = output.ravel().argMax().getInt(0);
        return TrainingState.relativeAction2Move(action, snake.getLastHeadDirection());
    }
}