001/*
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017package org.apache.commons.imaging.palette;
018
019import java.awt.image.BufferedImage;
020import java.util.ArrayList;
021import java.util.HashMap;
022import java.util.List;
023import java.util.Map;
024
025import org.apache.commons.imaging.ImageWriteException;
026import org.apache.commons.imaging.internal.Debug;
027
028public class MedianCutQuantizer {
029    private final boolean ignoreAlpha;
030
031    public MedianCutQuantizer(final boolean ignoreAlpha) {
032        this.ignoreAlpha = ignoreAlpha;
033    }
034
035    private Map<Integer, ColorCount> groupColors1(final BufferedImage image, final int max,
036            final int mask) {
037        final Map<Integer, ColorCount> colorMap = new HashMap<>();
038
039        final int width = image.getWidth();
040        final int height = image.getHeight();
041
042        final int[] row = new int[width];
043        for (int y = 0; y < height; y++) {
044            image.getRGB(0, y, width, 1, row, 0, width);
045            for (int x = 0; x < width; x++) {
046                int argb = row[x];
047
048                if (ignoreAlpha) {
049                    argb &= 0xffffff;
050                }
051                argb &= mask;
052
053                ColorCount color = colorMap.get(argb);
054                if (color == null) {
055                    color = new ColorCount(argb);
056                    colorMap.put(argb, color);
057                    if (colorMap.size() > max) {
058                        return null;
059                    }
060                }
061                color.count++;
062            }
063        }
064
065        return colorMap;
066    }
067
068    public Map<Integer, ColorCount> groupColors(final BufferedImage image, final int maxColors) {
069        final int max = Integer.MAX_VALUE;
070
071        for (int i = 0; i < 8; i++) {
072            int mask = 0xff & (0xff << i);
073            mask = mask | (mask << 8) | (mask << 16) | (mask << 24);
074
075            Debug.debug("mask(" + i + "): " + mask + " (" + Integer.toHexString(mask) + ")");
076
077            final Map<Integer, ColorCount> result = groupColors1(image, max, mask);
078            if (result != null) {
079                return result;
080            }
081        }
082        throw new Error("");
083    }
084
085    public Palette process(final BufferedImage image, final int maxColors,
086            final MedianCut medianCut)
087            throws ImageWriteException {
088        final Map<Integer, ColorCount> colorMap = groupColors(image, maxColors);
089
090        final int discreteColors = colorMap.size();
091        if (discreteColors <= maxColors) {
092            Debug.debug("lossless palette: " + discreteColors);
093
094            final int[] palette = new int[discreteColors];
095            final List<ColorCount> colorCounts = new ArrayList<>(
096                    colorMap.values());
097
098            for (int i = 0; i < colorCounts.size(); i++) {
099                final ColorCount colorCount = colorCounts.get(i);
100                palette[i] = colorCount.argb;
101                if (ignoreAlpha) {
102                    palette[i] |= 0xff000000;
103                }
104            }
105
106            return new SimplePalette(palette);
107        }
108
109        Debug.debug("discrete colors: " + discreteColors);
110
111        final List<ColorGroup> colorGroups = new ArrayList<>();
112        final ColorGroup root = new ColorGroup(new ArrayList<>(colorMap.values()), ignoreAlpha);
113        colorGroups.add(root);
114
115        while (colorGroups.size() < maxColors) {
116            if (!medianCut.performNextMedianCut(colorGroups, ignoreAlpha)) {
117                break;
118            }
119        }
120
121        final int paletteSize = colorGroups.size();
122        Debug.debug("palette size: " + paletteSize);
123
124        final int[] palette = new int[paletteSize];
125
126        for (int i = 0; i < colorGroups.size(); i++) {
127            final ColorGroup colorGroup = colorGroups.get(i);
128
129            palette[i] = colorGroup.getMedianValue();
130
131            colorGroup.paletteIndex = i;
132
133            if (colorGroup.getColorCounts().isEmpty()) {
134                throw new ImageWriteException("empty color_group: "
135                        + colorGroup);
136            }
137        }
138
139        if (paletteSize > discreteColors) {
140            throw new ImageWriteException("palette_size > discrete_colors");
141        }
142
143        return new MedianCutPalette(root, palette);
144    }
145}