001/*
002 *  Licensed under the Apache License, Version 2.0 (the "License");
003 *  you may not use this file except in compliance with the License.
004 *  You may obtain a copy of the License at
005 *
006 *       http://www.apache.org/licenses/LICENSE-2.0
007 *
008 *  Unless required by applicable law or agreed to in writing, software
009 *  distributed under the License is distributed on an "AS IS" BASIS,
010 *  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
011 *  See the License for the specific language governing permissions and
012 *  limitations under the License.
013 *  under the License.
014 */
015
016package org.apache.commons.imaging.formats.jpeg.decoder;
017
018import static org.apache.commons.imaging.common.BinaryFunctions.read2Bytes;
019import static org.apache.commons.imaging.common.BinaryFunctions.readBytes;
020
021import java.awt.image.BufferedImage;
022import java.awt.image.ColorModel;
023import java.awt.image.DataBuffer;
024import java.awt.image.DirectColorModel;
025import java.awt.image.Raster;
026import java.awt.image.WritableRaster;
027import java.io.ByteArrayInputStream;
028import java.io.IOException;
029import java.util.ArrayList;
030import java.util.Arrays;
031import java.util.List;
032import java.util.Properties;
033
034import org.apache.commons.imaging.ImageReadException;
035import org.apache.commons.imaging.color.ColorConversions;
036import org.apache.commons.imaging.common.BinaryFileParser;
037import org.apache.commons.imaging.common.bytesource.ByteSource;
038import org.apache.commons.imaging.formats.jpeg.JpegConstants;
039import org.apache.commons.imaging.formats.jpeg.JpegUtils;
040import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment;
041import org.apache.commons.imaging.formats.jpeg.segments.DhtSegment.HuffmanTable;
042import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment;
043import org.apache.commons.imaging.formats.jpeg.segments.DqtSegment.QuantizationTable;
044import org.apache.commons.imaging.formats.jpeg.segments.SofnSegment;
045import org.apache.commons.imaging.formats.jpeg.segments.SosSegment;
046
047public class JpegDecoder extends BinaryFileParser implements JpegUtils.Visitor {
    /*
     * JPEG is an advanced image format that takes significant computation to
     * decode. Keep decoding fast:
     * - Don't allocate memory inside loops; allocate it once and reuse it.
     * - Minimize calculations per pixel and per block (using lookup tables
     *   for YCbCr->RGB conversion doubled performance).
     * - Math.round() is slow; use (int) (x + 0.5f) instead for positive
     *   numbers.
     */
056
057    private final DqtSegment.QuantizationTable[] quantizationTables = new DqtSegment.QuantizationTable[4];
058    private final DhtSegment.HuffmanTable[] huffmanDCTables = new DhtSegment.HuffmanTable[4];
059    private final DhtSegment.HuffmanTable[] huffmanACTables = new DhtSegment.HuffmanTable[4];
060    private SofnSegment sofnSegment;
061    private SosSegment sosSegment;
062    private final float[][] scaledQuantizationTables = new float[4][];
063    private BufferedImage image;
064    private ImageReadException imageReadException;
065    private IOException ioException;
066    private final int[] zz = new int[64];
067    private final int[] blockInt = new int[64];
068    private final float[] block = new float[64];
069
070    @Override
071    public boolean beginSOS() {
072        return true;
073    }
074
075    @Override
076    public void visitSOS(final int marker, final byte[] markerBytes, final byte[] imageData) {
077        final ByteArrayInputStream is = new ByteArrayInputStream(imageData);
078        try {
079            // read the scan header
080            final int segmentLength = read2Bytes("segmentLength", is,"Not a Valid JPEG File", getByteOrder());
081            final byte[] sosSegmentBytes = readBytes("SosSegment", is, segmentLength - 2, "Not a Valid JPEG File");
082            sosSegment = new SosSegment(marker, sosSegmentBytes);
083            // read the payload of the scan, this is the remainder of image data after the header
084            // the payload contains the entropy-encoded segments (or ECS) divided by RST markers
085            // or only one ECS if the entropy-encoded data is not divided by RST markers
086            // length of payload = length of image data - length of data already read
087            final int[] scanPayload = new int[imageData.length - segmentLength];
088            int payloadReadCount = 0;
089            while (payloadReadCount < scanPayload.length) {
090                scanPayload[payloadReadCount] = is.read();
091                payloadReadCount++;
092            }
093
094            int hMax = 0;
095            int vMax = 0;
096            for (int i = 0; i < sofnSegment.numberOfComponents; i++) {
097                hMax = Math.max(hMax,
098                        sofnSegment.getComponents(i).horizontalSamplingFactor);
099                vMax = Math.max(vMax,
100                        sofnSegment.getComponents(i).verticalSamplingFactor);
101            }
102            final int hSize = 8 * hMax;
103            final int vSize = 8 * vMax;
104
105            final int xMCUs = (sofnSegment.width + hSize - 1) / hSize;
106            final int yMCUs = (sofnSegment.height + vSize - 1) / vSize;
107            final Block[] mcu = allocateMCUMemory();
108            final Block[] scaledMCU = new Block[mcu.length];
109            for (int i = 0; i < scaledMCU.length; i++) {
110                scaledMCU[i] = new Block(hSize, vSize);
111            }
112            final int[] preds = new int[sofnSegment.numberOfComponents];
113            ColorModel colorModel;
114            WritableRaster raster;
115            switch (sofnSegment.numberOfComponents) {
116            case 4:
117                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00, 0x000000ff);
118                final int[] bandMasks = new int[] { 0x00ff0000, 0x0000ff00, 0x000000ff };
119                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT, sofnSegment.width, sofnSegment.height, bandMasks, null);
120                break;
121            case 3:
122                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
123                        0x000000ff);
124                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
125                        sofnSegment.width, sofnSegment.height, new int[] {
126                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
127                break;
128            case 1:
129                colorModel = new DirectColorModel(24, 0x00ff0000, 0x0000ff00,
130                        0x000000ff);
131                raster = Raster.createPackedRaster(DataBuffer.TYPE_INT,
132                        sofnSegment.width, sofnSegment.height, new int[] {
133                                0x00ff0000, 0x0000ff00, 0x000000ff }, null);
134                // FIXME: why do images come out too bright with CS_GRAY?
135                // colorModel = new ComponentColorModel(
136                // ColorSpace.getInstance(ColorSpace.CS_GRAY), false, true,
137                // Transparency.OPAQUE, DataBuffer.TYPE_BYTE);
138                // raster = colorModel.createCompatibleWritableRaster(
139                // sofnSegment.width, sofnSegment.height);
140                break;
141            default:
142                throw new ImageReadException(sofnSegment.numberOfComponents
143                        + " components are invalid or unsupported");
144            }
145            final DataBuffer dataBuffer = raster.getDataBuffer();
146
147            final JpegInputStream[] bitInputStreams = splitByRstMarkers(scanPayload);
148            int bitInputStreamCount = 0;
149            JpegInputStream bitInputStream = bitInputStreams[0];
150
151            for (int y1 = 0; y1 < vSize * yMCUs; y1 += vSize) {
152                for (int x1 = 0; x1 < hSize * xMCUs; x1 += hSize) {
153                    // Provide the next interval if an interval is read until it's end
154                    // as long there are unread intervals available
155                    if (!bitInputStream.hasNext()) {
156                        bitInputStreamCount++;
157                        if (bitInputStreamCount < bitInputStreams.length) {
158                            bitInputStream = bitInputStreams[bitInputStreamCount];
159                        }
160                    }
161
162                    readMCU(bitInputStream, preds, mcu);
163                    rescaleMCU(mcu, hSize, vSize, scaledMCU);
164                    int srcRowOffset = 0;
165                    int dstRowOffset = y1 * sofnSegment.width + x1;
166                    for (int y2 = 0; y2 < vSize && y1 + y2 < sofnSegment.height; y2++) {
167                        for (int x2 = 0; x2 < hSize
168                                && x1 + x2 < sofnSegment.width; x2++) {
169                            if (scaledMCU.length == 4) {
170                                final int C = scaledMCU[0].samples[srcRowOffset + x2];
171                                final int M = scaledMCU[1].samples[srcRowOffset + x2];
172                                final int Y = scaledMCU[2].samples[srcRowOffset + x2];
173                                final int K = scaledMCU[3].samples[srcRowOffset + x2];
174                                final int rgb = ColorConversions.convertCMYKtoRGB(C, M, Y, K);
175                                dataBuffer.setElem(dstRowOffset + x2, rgb);
176                            } else if (scaledMCU.length == 3) {
177                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
178                                final int Cb = scaledMCU[1].samples[srcRowOffset + x2];
179                                final int Cr = scaledMCU[2].samples[srcRowOffset + x2];
180                                final int rgb = YCbCrConverter.convertYCbCrToRGB(Y,
181                                        Cb, Cr);
182                                dataBuffer.setElem(dstRowOffset + x2, rgb);
183                            } else if (mcu.length == 1) {
184                                final int Y = scaledMCU[0].samples[srcRowOffset + x2];
185                                dataBuffer.setElem(dstRowOffset + x2, (Y << 16)
186                                        | (Y << 8) | Y);
187                            } else {
188                                throw new ImageReadException(
189                                        "Unsupported JPEG with " + mcu.length
190                                                + " components");
191                            }
192                        }
193                        srcRowOffset += hSize;
194                        dstRowOffset += sofnSegment.width;
195                    }
196                }
197            }
198            image = new BufferedImage(colorModel, raster,
199                    colorModel.isAlphaPremultiplied(), new Properties());
200            // byte[] remainder = super.getStreamBytes(is);
201            // for (int i = 0; i < remainder.length; i++)
202            // {
203            // System.out.println("" + i + " = " +
204            // Integer.toHexString(remainder[i]));
205            // }
206        } catch (final ImageReadException imageReadEx) {
207            imageReadException = imageReadEx;
208        } catch (final IOException ioEx) {
209            ioException = ioEx;
210        } catch (final RuntimeException ex) {
211            // Corrupt images can throw NPE and IOOBE
212            imageReadException = new ImageReadException("Error parsing JPEG",ex);
213        }
214    }
215
216    @Override
217    public boolean visitSegment(final int marker, final byte[] markerBytes,
218            final int segmentLength, final byte[] segmentLengthBytes, final byte[] segmentData)
219            throws ImageReadException, IOException {
220        final int[] sofnSegments = {
221                JpegConstants.SOF0_MARKER,
222                JpegConstants.SOF1_MARKER,
223                JpegConstants.SOF2_MARKER,
224                JpegConstants.SOF3_MARKER,
225                JpegConstants.SOF5_MARKER,
226                JpegConstants.SOF6_MARKER,
227                JpegConstants.SOF7_MARKER,
228                JpegConstants.SOF9_MARKER,
229                JpegConstants.SOF10_MARKER,
230                JpegConstants.SOF11_MARKER,
231                JpegConstants.SOF13_MARKER,
232                JpegConstants.SOF14_MARKER,
233                JpegConstants.SOF15_MARKER,
234        };
235
236        if (Arrays.binarySearch(sofnSegments, marker) >= 0) {
237            if (marker != JpegConstants.SOF0_MARKER) {
238                throw new ImageReadException("Only sequential, baseline JPEGs "
239                        + "are supported at the moment");
240            }
241            sofnSegment = new SofnSegment(marker, segmentData);
242        } else if (marker == JpegConstants.DQT_MARKER) {
243            final DqtSegment dqtSegment = new DqtSegment(marker, segmentData);
244            for (final QuantizationTable table : dqtSegment.quantizationTables) {
245                if (0 > table.destinationIdentifier
246                        || table.destinationIdentifier >= quantizationTables.length) {
247                    throw new ImageReadException(
248                            "Invalid quantization table identifier "
249                                    + table.destinationIdentifier);
250                }
251                quantizationTables[table.destinationIdentifier] = table;
252                final int[] quantizationMatrixInt = new int[64];
253                ZigZag.zigZagToBlock(table.getElements(), quantizationMatrixInt);
254                final float[] quantizationMatrixFloat = new float[64];
255                for (int j = 0; j < 64; j++) {
256                    quantizationMatrixFloat[j] = quantizationMatrixInt[j];
257                }
258                Dct.scaleDequantizationMatrix(quantizationMatrixFloat);
259                scaledQuantizationTables[table.destinationIdentifier] = quantizationMatrixFloat;
260            }
261        } else if (marker == JpegConstants.DHT_MARKER) {
262            final DhtSegment dhtSegment = new DhtSegment(marker, segmentData);
263            for (final HuffmanTable table : dhtSegment.huffmanTables) {
264                DhtSegment.HuffmanTable[] tables;
265                if (table.tableClass == 0) {
266                    tables = huffmanDCTables;
267                } else if (table.tableClass == 1) {
268                    tables = huffmanACTables;
269                } else {
270                    throw new ImageReadException("Invalid huffman table class "
271                            + table.tableClass);
272                }
273                if (0 > table.destinationIdentifier
274                        || table.destinationIdentifier >= tables.length) {
275                    throw new ImageReadException(
276                            "Invalid huffman table identifier "
277                                    + table.destinationIdentifier);
278                }
279                tables[table.destinationIdentifier] = table;
280            }
281        }
282        return true;
283    }
284
285    private void rescaleMCU(final Block[] dataUnits, final int hSize, final int vSize, final Block[] ret) {
286        for (int i = 0; i < dataUnits.length; i++) {
287            final Block dataUnit = dataUnits[i];
288            if (dataUnit.width == hSize && dataUnit.height == vSize) {
289                System.arraycopy(dataUnit.samples, 0, ret[i].samples, 0, hSize
290                        * vSize);
291            } else {
292                final int hScale = hSize / dataUnit.width;
293                final int vScale = vSize / dataUnit.height;
294                if (hScale == 2 && vScale == 2) {
295                    int srcRowOffset = 0;
296                    int dstRowOffset = 0;
297                    for (int y = 0; y < dataUnit.height; y++) {
298                        for (int x = 0; x < hSize; x++) {
299                            final int sample = dataUnit.samples[srcRowOffset + (x >> 1)];
300                            ret[i].samples[dstRowOffset + x] = sample;
301                            ret[i].samples[dstRowOffset + hSize + x] = sample;
302                        }
303                        srcRowOffset += dataUnit.width;
304                        dstRowOffset += 2 * hSize;
305                    }
306                } else {
307                    // FIXME: optimize
308                    int dstRowOffset = 0;
309                    for (int y = 0; y < vSize; y++) {
310                        for (int x = 0; x < hSize; x++) {
311                            ret[i].samples[dstRowOffset + x] = dataUnit.samples[(y / vScale)
312                                    * dataUnit.width + (x / hScale)];
313                        }
314                        dstRowOffset += hSize;
315                    }
316                }
317            }
318        }
319    }
320
321    private Block[] allocateMCUMemory() throws ImageReadException {
322        final Block[] mcu = new Block[sosSegment.numberOfComponents];
323        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
324            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
325            SofnSegment.Component frameComponent = null;
326            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
327                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
328                    frameComponent = sofnSegment.getComponents(j);
329                    break;
330                }
331            }
332            if (frameComponent == null) {
333                throw new ImageReadException("Invalid component");
334            }
335            final Block fullBlock = new Block(
336                    8 * frameComponent.horizontalSamplingFactor,
337                    8 * frameComponent.verticalSamplingFactor);
338            mcu[i] = fullBlock;
339        }
340        return mcu;
341    }
342
343    private void readMCU(final JpegInputStream is, final int[] preds, final Block[] mcu)
344            throws ImageReadException {
345        for (int i = 0; i < sosSegment.numberOfComponents; i++) {
346            final SosSegment.Component scanComponent = sosSegment.getComponents(i);
347            SofnSegment.Component frameComponent = null;
348            for (int j = 0; j < sofnSegment.numberOfComponents; j++) {
349                if (sofnSegment.getComponents(j).componentIdentifier == scanComponent.scanComponentSelector) {
350                    frameComponent = sofnSegment.getComponents(j);
351                    break;
352                }
353            }
354            if (frameComponent == null) {
355                throw new ImageReadException("Invalid component");
356            }
357            final Block fullBlock = mcu[i];
358            for (int y = 0; y < frameComponent.verticalSamplingFactor; y++) {
359                for (int x = 0; x < frameComponent.horizontalSamplingFactor; x++) {
360                    Arrays.fill(zz, 0);
361                    // page 104 of T.81
362                    final int t = decode(
363                            is,
364                            huffmanDCTables[scanComponent.dcCodingTableSelector]);
365                    int diff = receive(t, is);
366                    diff = extend(diff, t);
367                    zz[0] = preds[i] + diff;
368                    preds[i] = zz[0];
369
370                    // "Decode_AC_coefficients", figure F.13, page 106 of T.81
371                    int k = 1;
372                    while (true) {
373                        final int rs = decode(
374                                is,
375                                huffmanACTables[scanComponent.acCodingTableSelector]);
376                        final int ssss = rs & 0xf;
377                        final int rrrr = rs >> 4;
378                        final int r = rrrr;
379
380                        if (ssss == 0) {
381                            if (r != 15) {
382                                break;
383                            }
384                            k += 16;
385                        } else {
386                            k += r;
387
388                            // "Decode_ZZ(k)", figure F.14, page 107 of T.81
389                            zz[k] = receive(ssss, is);
390                            zz[k] = extend(zz[k], ssss);
391
392                            if (k == 63) {
393                                break;
394                            }
395                            k++;
396                        }
397                    }
398
399                    final int shift = (1 << (sofnSegment.precision - 1));
400                    final int max = (1 << sofnSegment.precision) - 1;
401
402                    final float[] scaledQuantizationTable = scaledQuantizationTables[frameComponent.quantTabDestSelector];
403                    ZigZag.zigZagToBlock(zz, blockInt);
404                    for (int j = 0; j < 64; j++) {
405                        block[j] = blockInt[j] * scaledQuantizationTable[j];
406                    }
407                    Dct.inverseDCT8x8(block);
408
409                    int dstRowOffset = 8 * y * 8
410                            * frameComponent.horizontalSamplingFactor + 8 * x;
411                    int srcNext = 0;
412                    for (int yy = 0; yy < 8; yy++) {
413                        for (int xx = 0; xx < 8; xx++) {
414                            float sample = block[srcNext++];
415                            sample += shift;
416                            int result;
417                            if (sample < 0) {
418                                result = 0;
419                            } else if (sample > max) {
420                                result = max;
421                            } else {
422                                result = fastRound(sample);
423                            }
424                            fullBlock.samples[dstRowOffset + xx] = result;
425                        }
426                        dstRowOffset += 8 * frameComponent.horizontalSamplingFactor;
427                    }
428                }
429            }
430        }
431    }
432
433    /**
434     * Returns an array of JpegInputStream where each field contains the JpegInputStream
435     * for one interval.
436     * @param scanPayload array to read intervals from
437     * @return JpegInputStreams for all intervals, at least one stream is always provided
438     */
439    static JpegInputStream[] splitByRstMarkers(final int[] scanPayload) {
440        final List<Integer> intervalStarts = getIntervalStartPositions(scanPayload);
441        // get number of intervals in payload to init an array of appropriate length
442        final int intervalCount = intervalStarts.size();
443        final JpegInputStream[] streams = new JpegInputStream[intervalCount];
444        for (int i = 0; i < intervalCount; i++) {
445            final int from = intervalStarts.get(i);
446            int to;
447            if (i < intervalCount - 1) {
448                // because each restart marker needs two bytes the end of
449                // this interval is two bytes before the next interval starts
450                to = intervalStarts.get(i + 1) - 2;
451            } else { // the last interval ends with the array
452                to = scanPayload.length;
453            }
454            final int[] interval = Arrays.copyOfRange(scanPayload, from, to);
455            streams[i] = new JpegInputStream(interval);
456        }
457        return streams;
458    }
459
460    /**
461     * Returns the positions of where each interval in the provided array starts. The number
462     * of start positions is also the count of intervals while the number of restart markers
463     * found is equal to the number of start positions minus one (because restart markers
464     * are between intervals).
465     *
466     * @param scanPayload array to examine
467     * @return the start positions
468     */
469    static List<Integer> getIntervalStartPositions(final int[] scanPayload) {
470        final List<Integer> intervalStarts = new ArrayList<>();
471        intervalStarts.add(0);
472        boolean foundFF = false;
473        boolean foundD0toD7 = false;
474        int pos = 0;
475        while (pos < scanPayload.length) {
476            if (foundFF) {
477                // found 0xFF D0 .. 0xFF D7 => RST marker
478                if (scanPayload[pos] >= (0xff & JpegConstants.RST0_MARKER) &&
479                    scanPayload[pos] <= (0xff & JpegConstants.RST7_MARKER)) {
480                    foundD0toD7 = true;
481                } else { // found 0xFF followed by something else => no RST marker
482                    foundFF = false;
483                }
484            }
485
486            if (scanPayload[pos] == 0xFF) {
487                foundFF = true;
488            }
489
490            // true if one of the RST markers was found
491            if (foundFF && foundD0toD7) {
492                // we need to add the position after the current position because
493                // we had already read 0xFF and are now at 0xDn
494                intervalStarts.add(pos + 1);
495                foundFF = foundD0toD7 = false;
496            }
497            pos++;
498        }
499        return intervalStarts;
500    }
501
502    private static int fastRound(final float x) {
503        return (int) (x + 0.5f);
504    }
505
506    private int extend(int v, final int t) {
507        // "EXTEND", section F.2.2.1, figure F.12, page 105 of T.81
508        int vt = (1 << (t - 1));
509        if (v < vt) {
510            vt = (-1 << t) + 1;
511            v += vt;
512        }
513        return v;
514    }
515
516    private int receive(final int ssss, final JpegInputStream is) throws ImageReadException {
517        // "RECEIVE", section F.2.2.4, figure F.17, page 110 of T.81
518        int i = 0;
519        int v = 0;
520        while (i != ssss) {
521            i++;
522            v = (v << 1) + is.nextBit();
523        }
524        return v;
525    }
526
527    private int decode(final JpegInputStream is, final DhtSegment.HuffmanTable huffmanTable)
528            throws ImageReadException {
529        // "DECODE", section F.2.2.3, figure F.16, page 109 of T.81
530        int i = 1;
531        int code = is.nextBit();
532        while (code > huffmanTable.getMaxCode(i)) {
533            i++;
534            code = (code << 1) | is.nextBit();
535        }
536        int j = huffmanTable.getValPtr(i);
537        j += code - huffmanTable.getMinCode(i);
538        return huffmanTable.getHuffVal(j);
539    }
540
541    public BufferedImage decode(final ByteSource byteSource) throws IOException,
542            ImageReadException {
543        final JpegUtils jpegUtils = new JpegUtils();
544        jpegUtils.traverseJFIF(byteSource, this);
545        if (imageReadException != null) {
546            throw imageReadException;
547        }
548        if (ioException != null) {
549            throw ioException;
550        }
551        return image;
552    }
553}