画像異常検知用の学習データ自動生成ツールを作った

初めに

MVTechのような、異常検知用のデータ生成ツールをUnityで作った。ざっくりいうと、正常な部品を模擬した3Dモデルに対して以下の操作を繰り替えして、異常データを生成する。

  1. 異常データを模擬するため、3Dモデルのメッシュの一部をへこませる。
  2. 画像を保存する。

異常データは以下のような感じ。メッシュの一部をへこませるので、ある程度頂点数があるカプセルを使用している。また、柄もある程度違いが出るように、木目調のマテリアルを貼り付けている。

ちなみにスクリプトはほぼChatGPTで生成している。

異常なオブジェクトの生成

まずは先ほどの写真のようなオブジェクト(カプセルでも何でもOK)を作成する。

そして、以下のスクリプトをアタッチする。これでこのオブジェクトは生成時にへこむようになる。また、このスクリプトを無効化しておけば、正常画像を生成できる。deformationAmountで、へこむ量を調整できる。

using System.Collections.Generic;
using UnityEngine;

[RequireComponent(typeof(MeshFilter))]
public class RandomVertexDeformer : MonoBehaviour
{
    private Mesh mesh;
    public float deformationAmount = -0.5f;  // へこむ量を設定

    void Start()
    {
        mesh = GetComponent<MeshFilter>().mesh;
        DeformRandomVertices(5);
    }

    void DeformRandomVertices(int count)
    {
        if (mesh == null || mesh.vertexCount == 0) return;

        Vector3[] vertices = mesh.vertices;
        HashSet<int> indicesSet = new HashSet<int>();

        while (indicesSet.Count < count)
        {
            int randomIndex = Random.Range(0, mesh.vertexCount);
            indicesSet.Add(randomIndex);
        }

        List<int> selectedIndices = new List<int>(indicesSet);
        int vertexToDeformIndex = selectedIndices[Random.Range(0, selectedIndices.Count)]; // 5つの選択された頂点の中から1つをランダムに選択

        // へこませる処理
        Vector3 normal = mesh.normals[vertexToDeformIndex];
        vertices[vertexToDeformIndex] += normal * deformationAmount;

        mesh.vertices = vertices;
        mesh.RecalculateNormals();  // 法線を再計算して照明が正しく動作するようにする
    }
}

写真保存

空のオブジェクトを作成して、以下のスクリプトをアタッチする。このスクリプトは、以下を繰り返し行ってくれる。

  1. オブジェクトを生成(位置、回転はランダム)
  2. スクリーンショットを保存
  3. オブジェクトを削除
using UnityEngine;
using System.Collections;

public class ObjectSpawnerAndCameraCapture : MonoBehaviour
{
    public GameObject objectPrefab;        // インスタンス化するオブジェクトのPrefab
    public Transform spawnPosition;        // オブジェクトを生成する位置
    public Camera captureCamera;           // 画像をキャプチャするカメラ
    public int totalCaptures = 10;         // 保存する画像の合計数
    public float timeBetweenCaptures = 2f; // 画像をキャプチャする間隔(秒)

    private int currentCaptureCount = 0;
    private string basePath;
    private GameObject spawnedObject;      // 生成されたオブジェクトの参照
    
    public Vector3 moveRange = new Vector3(1, 1, 1);  // 移動する最大の範囲
    public Vector3 rotationRange = new Vector3(360, 360, 360);  // 各軸における最大の回転角度

    private void Start()
    {
        basePath = Application.persistentDataPath + "/screenshot";
        Debug.Log(basePath);
        StartCoroutine(CaptureRoutine());
    }

    private IEnumerator CaptureRoutine()
    {
        while (currentCaptureCount < totalCaptures)
        {
            Debug.Log(moveRange);
            // オブジェクトの生成
            spawnedObject = Instantiate(objectPrefab, spawnPosition.position, spawnPosition.rotation);
            
            // ランダムに移動
            float offsetX = Random.Range(-moveRange.x, moveRange.x);
            float offsetY = Random.Range(-moveRange.y, moveRange.y);
            float offsetZ = Random.Range(-moveRange.z, moveRange.z);

            spawnedObject.transform.position  += new Vector3(offsetX, offsetY, offsetZ);

            float randomRotationX = Random.Range(0, rotationRange.x);
            float randomRotationY = Random.Range(0, rotationRange.y);
            float randomRotationZ = Random.Range(0, rotationRange.z);// + spawnPosition.rotation.z;

            spawnedObject.transform.eulerAngles = new Vector3(randomRotationX, 0, 90);
            

            yield return new WaitForSeconds(1f);  // 必要に応じて待機時間を調整してください

            // カメラ画像の保存
            CaptureCameraImage();

            // オブジェクトの削除
            Destroy(spawnedObject);

            currentCaptureCount++;
            yield return new WaitForSeconds(timeBetweenCaptures - 1f);
        }
    }

    private void CaptureCameraImage()
    {
        RenderTexture renderTexture = new RenderTexture(captureCamera.pixelWidth, captureCamera.pixelHeight, 24);
        captureCamera.targetTexture = renderTexture;

        // カメラの画像をキャプチャする
        captureCamera.Render();
        RenderTexture.active = renderTexture;

        Texture2D screenshot = new Texture2D(captureCamera.pixelWidth, captureCamera.pixelHeight, TextureFormat.RGB24, false);
        screenshot.ReadPixels(new Rect(0, 0, captureCamera.pixelWidth, captureCamera.pixelHeight), 0, 0);
        screenshot.Apply();

        byte[] bytes = screenshot.EncodeToPNG();
        System.IO.File.WriteAllBytes(basePath + currentCaptureCount + ".png", bytes);

        Destroy(screenshot);

        captureCamera.targetTexture = null;
        RenderTexture.active = null;
        Destroy(renderTexture);
    }
}

objectPrefabにさっき作成した3Dモデル、spawnPositionにオブジェクトを生成する位置(カメラ前になるようにしておく)、captureCameraにカメラ、totalCapturesにキャプチャする枚数をセットする。

実行

準備ができたら実行すると、勝手に画像が生成されていく。

テスト

異常検知を試してみた結果は以下のとおり。(真ん中のマスク画像は適当)ちゃんと異常箇所が着目されていることがわかる。