山傘のプログラミング勉強日記

プログラミングに関する日記とどうでもよい雑記からなるブログです。

[Aizu Online Judge] 平方分割に関する問題 [DSL1A: Range Minimum Query, DSL1B: Range Sum Query]

平方分割

セグメント木を勉強しようと思い解説記事を読んでいたんですが、いまいち理解できませんでした。調べている最中に平方分割という手法を知り、自分でも少し理解できそうだったのでそれを試すことにしました。

平方分割の解説記事として下のサイトが分かりやすかったです。

kujira16.hateblo.jp

データを分割して整理するというものです。累積和とか群数列に少し似ている部分があると思いました。

平方分割に関してAOJの二つの問題を解きました。

DSL1A: Range Minimum Query

Range Minimum Query (RMQ) | データ構造ライブラリ | Aizu Online Judge

考え方

数列  A の大きさは  n と与えられますが、便宜上、次の大きさ  l に拡大します。

 m =\lceil \sqrt{n} \  \rceil

 l = m^2

また、バケット  B の大きさ  k k = \lceil  \dfrac{n}{m} \rceil です。バケットの中の数字は  m 個となっています。

例えば、

 A = \{1, 5, 2, 0, 6, 8\}

とすると、 m = 3 k = 2

 B_0 = \min \{1, 5, 2\} = 1

 B_1 = \min \{0, 6, 8\} = 0

となります。バケットの中に入っている値を直接参照できませんが、バケットの値 = バケットの代表値は取得できることがポイントです。

値の更新

要素  i x に変更する操作は、 A の値を変更した後に、  i が属するバケットを更新します。

 i が属するバケットの要素  b は、

 b = \lfloor \dfrac{i}{m} \rfloor

となります。逆に、バケットの要素  b に含まれる  A の要素  j は、

 bm \leq j \lt (b + 1)m

となります。

区間の最小値を探す

区間  [i, j] における最小値を探します。

調べる区間にあるバケットの要素が全て含まれていれば、そのバケットの代表値を返せばよいです。そうではないとき、 A の要素を一つずつ調べます。

コード

import java.util.Arrays;
import java.util.Scanner;

public class ProblemA2 {
    static int[]a;
    static int[]bucketMin;
    static int m;
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        int q = scan.nextInt();
        int d = 2147483647;
        m = (int)Math.ceil(Math.sqrt(n));
        a = new int[m * m];
        // バケットの個数
        int bucketNum = (int)Math.ceil((double)n / m);
        bucketMin = new int[bucketNum];
        Arrays.fill(a, d);
        Arrays.fill(bucketMin, d);
        for(int i = 0; i < q; i++) {
            int com = scan.nextInt();
            if(com == 0) {
                int x = scan.nextInt();
                int y = scan.nextInt();
                update(x, y);
            }else {
                int x = scan.nextInt();
                int y = scan.nextInt();
                int k  = find(x, y);
                System.out.println(k);
            }
        }
        scan.close();
    }
    // データとバケットを更新
    static void update(int i, int x) {
        a[i] = x;
        int k = i / m;
        int minVal = x;
        for(int j = k * m; j < (k + 1) * m; j++) {
            minVal = Math.min(minVal, a[j]);
        }
        bucketMin[k] = minVal;

    }
    // [i, j]の最小値を返す
    static int find(int i, int j) {
        int minVal = Integer.MAX_VALUE;
        int k = i;
        while(k <= j){
            if(k % m == 0 && k + m <= j) {
                minVal = Math.min(minVal, bucketMin[k / m]);
                k = k + m;
            }else {
                minVal = Math.min(minVal, a[k]);
                k++;
            }
        }
        return minVal;
    }
    static void disp(int[] a) {
        for(int i : a) {
            System.out.print(" " + i + " ");
        }
        System.out.println();
    }
}

DSL1B: Range Sum Query

上記とほとんど同じです。

コード

import java.util.Scanner;

public class ProblemB2 {
    static int[]a;
    static int[]bucketSum;
    static int m;
    public static void main(String[] args) {
        Scanner scan = new Scanner(System.in);
        int n = scan.nextInt();
        int q = scan.nextInt();
        m = (int)Math.ceil(Math.sqrt(n));
        a = new int[m * m];
        // バケットの個数
        int bucketNum = (int)Math.ceil((double)n / m);
        bucketSum = new int[bucketNum];
        for(int i = 0; i < q; i++) {
            int com = scan.nextInt();
            if(com == 0) {
                int x = scan.nextInt();
                int y = scan.nextInt();
                update(x-1, y);
            }else {
                int x = scan.nextInt();
                int y = scan.nextInt();
                int k  = getSum(x-1, y-1);
                System.out.println(k);
            }
        }
        scan.close();
    }
    // データとバケットを更新
    static void update(int i, int x) {
        a[i] += x;
        int k = i / m;
        int sumVal = 0;
        for(int j = k * m; j < (k + 1) * m; j++) {
            sumVal += a[j];
        }
        bucketSum[k] = sumVal;
    }
    static int getSum(int i, int j) {
        int sum = 0;
        int k = i;
        while(k <= j){
            if(k % m == 0 && k + m <= j) {
                sum += bucketSum[k / m];
                k = k + m;
            }else {
                sum += a[k];
                k++;
            }
        }
        return sum;
    }
}