001/**
002 * Licensed to the Apache Software Foundation (ASF) under one
003 * or more contributor license agreements.  See the NOTICE file
004 * distributed with this work for additional information
005 * regarding copyright ownership.  The ASF licenses this file
006 * to you under the Apache License, Version 2.0 (the
007 * "License"); you may not use this file except in compliance
008 * with the License.  You may obtain a copy of the License at
009 *
010 *     http://www.apache.org/licenses/LICENSE-2.0
011 *
012 * Unless required by applicable law or agreed to in writing, software
013 * distributed under the License is distributed on an "AS IS" BASIS,
014 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015 * See the License for the specific language governing permissions and
016 * limitations under the License.
017 */
018package org.apache.commons.crypto.stream;
019
020import java.io.IOException;
021import java.io.InputStream;
022import java.nio.ByteBuffer;
023import java.nio.channels.ReadableByteChannel;
024import java.security.GeneralSecurityException;
025import java.util.Properties;
026
027import javax.crypto.Cipher;
028import javax.crypto.spec.IvParameterSpec;
029import javax.crypto.spec.SecretKeySpec;
030
031import org.apache.commons.crypto.cipher.CryptoCipher;
032import org.apache.commons.crypto.cipher.CryptoCipherFactory;
033import org.apache.commons.crypto.stream.input.ChannelInput;
034import org.apache.commons.crypto.stream.input.Input;
035import org.apache.commons.crypto.stream.input.StreamInput;
036import org.apache.commons.crypto.utils.Utils;
037
038/**
039 * <p>
040 * CtrCryptoInputStream decrypts data. AES CTR mode is required in order to
041 * ensure that the plain text and cipher text have a 1:1 mapping. CTR crypto
042 * stream has stream characteristic which is useful for implement features like
043 * random seek. The decryption is buffer based. The key points of the decryption
044 * are (1) calculating the counter and (2) padding through stream position:
045 * </p>
046 * <p>
047 * counter = base + pos/(algorithm blocksize); padding = pos%(algorithm
048 * blocksize);
049 * </p>
050 * The underlying stream offset is maintained as state. It is not thread-safe.
051 */
052public class CtrCryptoInputStream extends CryptoInputStream {
053    /**
054     * Underlying stream offset
055     */
056    private long streamOffset = 0;
057
058    /**
059     * The initial IV.
060     */
061    private final byte[] initIV;
062
063    /**
064     * Initialization vector for the cipher.
065     */
066    private final byte[] iv;
067
068    /**
069     * Padding = pos%(algorithm blocksize); Padding is put into
070     * {@link #inBuffer} before any other data goes in. The purpose of padding
071     * is to put the input data at proper position.
072     */
073    private byte padding;
074
075    /**
076     * Flag to mark whether the cipher has been reset
077     */
078    private boolean cipherReset = false;
079
080    /**
081     * Constructs a {@link CtrCryptoInputStream}.
082     *
083     * @param properties The {@code Properties} class represents a set of
084     *        properties.
085     * @param inputStream the input stream.
086     * @param key crypto key for the cipher.
087     * @param iv Initialization vector for the cipher.
088     * @throws IOException if an I/O error occurs.
089     */
090    public CtrCryptoInputStream(final Properties properties, final InputStream inputStream, final byte[] key,
091            final byte[] iv) throws IOException {
092        this(properties, inputStream, key, iv, 0);
093    }
094
095    /**
096     * Constructs a {@link CtrCryptoInputStream}.
097     *
098     * @param properties The {@code Properties} class represents a set of
099     *        properties.
100     * @param channel the ReadableByteChannel instance.
101     * @param key crypto key for the cipher.
102     * @param iv Initialization vector for the cipher.
103     * @throws IOException if an I/O error occurs.
104     */
105    public CtrCryptoInputStream(final Properties properties, final ReadableByteChannel channel,
106            final byte[] key, final byte[] iv) throws IOException {
107        this(properties, channel, key, iv, 0);
108    }
109
110    /**
111     * Constructs a {@link CtrCryptoInputStream}.
112     *
113     * @param inputStream the input stream.
114     * @param cipher the CryptoCipher instance.
115     * @param bufferSize the bufferSize.
116     * @param key crypto key for the cipher.
117     * @param iv Initialization vector for the cipher.
118     * @throws IOException if an I/O error occurs.
119     */
120    protected CtrCryptoInputStream(final InputStream inputStream, final CryptoCipher cipher,
121            final int bufferSize, final byte[] key, final byte[] iv) throws IOException {
122        this(inputStream, cipher, bufferSize, key, iv, 0);
123    }
124
125    /**
126     * Constructs a {@link CtrCryptoInputStream}.
127     *
128     * @param channel the ReadableByteChannel instance.
129     * @param cipher the cipher instance.
130     * @param bufferSize the bufferSize.
131     * @param key crypto key for the cipher.
132     * @param iv Initialization vector for the cipher.
133     * @throws IOException if an I/O error occurs.
134     */
135    protected CtrCryptoInputStream(final ReadableByteChannel channel, final CryptoCipher cipher,
136            final int bufferSize, final byte[] key, final byte[] iv) throws IOException {
137        this(channel, cipher, bufferSize, key, iv, 0);
138    }
139
140    /**
141     * Constructs a {@link CtrCryptoInputStream}.
142     *
143     * @param input the input data.
144     * @param cipher the CryptoCipher instance.
145     * @param bufferSize the bufferSize.
146     * @param key crypto key for the cipher.
147     * @param iv Initialization vector for the cipher.
148     * @throws IOException if an I/O error occurs.
149     */
150    protected CtrCryptoInputStream(final Input input, final CryptoCipher cipher,
151            final int bufferSize, final byte[] key, final byte[] iv) throws IOException {
152        this(input, cipher, bufferSize, key, iv, 0);
153    }
154
155    /**
156     * Constructs a {@link CtrCryptoInputStream}.
157     *
158     * @param properties The {@code Properties} class represents a set of
159     *        properties.
160     * @param inputStream the InputStream instance.
161     * @param key crypto key for the cipher.
162     * @param iv Initialization vector for the cipher.
163     * @param streamOffset the start offset in the stream.
164     * @throws IOException if an I/O error occurs.
165     */
166    @SuppressWarnings("resource") // The CryptoCipher returned by getCipherInstance() is closed by CtrCryptoInputStream.
167    public CtrCryptoInputStream(final Properties properties, final InputStream inputStream, final byte[] key,
168            final byte[] iv, final long streamOffset) throws IOException {
169        this(inputStream, Utils.getCipherInstance(
170                "AES/CTR/NoPadding", properties),
171                CryptoInputStream.getBufferSize(properties), key, iv, streamOffset);
172    }
173
174    /**
175     * Constructs a {@link CtrCryptoInputStream}.
176     *
177     * @param properties The {@code Properties} class represents a set of
178     *        properties.
179     * @param in the ReadableByteChannel instance.
180     * @param key crypto key for the cipher.
181     * @param iv Initialization vector for the cipher.
182     * @param streamOffset the start offset in the stream.
183     * @throws IOException if an I/O error occurs.
184     */
185    @SuppressWarnings("resource") // The CryptoCipher returned by getCipherInstance() is closed by CtrCryptoInputStream.
186    public CtrCryptoInputStream(final Properties properties, final ReadableByteChannel in,
187            final byte[] key, final byte[] iv, final long streamOffset) throws IOException {
188        this(in, Utils.getCipherInstance(
189                "AES/CTR/NoPadding", properties),
190                CryptoInputStream.getBufferSize(properties), key, iv, streamOffset);
191    }
192
193    /**
194     * Constructs a {@link CtrCryptoInputStream}.
195     *
196     * @param inputStream the InputStream instance.
197     * @param cipher the CryptoCipher instance.
198     * @param bufferSize the bufferSize.
199     * @param key crypto key for the cipher.
200     * @param iv Initialization vector for the cipher.
201     * @param streamOffset the start offset in the stream.
202     * @throws IOException if an I/O error occurs.
203     */
204    protected CtrCryptoInputStream(final InputStream inputStream, final CryptoCipher cipher,
205            final int bufferSize, final byte[] key, final byte[] iv, final long streamOffset)
206            throws IOException {
207        this(new StreamInput(inputStream, bufferSize), cipher, bufferSize, key, iv,
208                streamOffset);
209    }
210
211    /**
212     * Constructs a {@link CtrCryptoInputStream}.
213     *
214     * @param channel the ReadableByteChannel instance.
215     * @param cipher the CryptoCipher instance.
216     * @param bufferSize the bufferSize.
217     * @param key crypto key for the cipher.
218     * @param iv Initialization vector for the cipher.
219     * @param streamOffset the start offset in the stream.
220     * @throws IOException if an I/O error occurs.
221     */
222    protected CtrCryptoInputStream(final ReadableByteChannel channel, final CryptoCipher cipher,
223            final int bufferSize, final byte[] key, final byte[] iv, final long streamOffset)
224            throws IOException {
225        this(new ChannelInput(channel), cipher, bufferSize, key, iv, streamOffset);
226    }
227
228    /**
229     * Constructs a {@link CtrCryptoInputStream}.
230     *
231     * @param input the input data.
232     * @param cipher the CryptoCipher instance.
233     * @param bufferSize the bufferSize.
234     * @param key crypto key for the cipher.
235     * @param iv Initialization vector for the cipher.
236     * @param streamOffset the start offset in the stream.
237     * @throws IOException if an I/O error occurs.
238     */
239    protected CtrCryptoInputStream(final Input input, final CryptoCipher cipher,
240            final int bufferSize, final byte[] key, final byte[] iv, final long streamOffset)
241            throws IOException {
242        super(input, cipher, bufferSize, new SecretKeySpec(key, "AES"),
243                new IvParameterSpec(iv));
244
245        this.initIV = iv.clone();
246        this.iv = iv.clone();
247
248        CryptoInputStream.checkStreamCipher(cipher);
249
250        resetStreamOffset(streamOffset);
251    }
252
253    /**
254     * Overrides the {@link CryptoInputStream#skip(long)}. Skips over and
255     * discards {@code n} bytes of data from this input stream.
256     *
257     * @param n the number of bytes to be skipped.
258     * @return the actual number of bytes skipped.
259     * @throws IOException if an I/O error occurs.
260     */
261    @Override
262    public long skip(long n) throws IOException {
263        Utils.checkArgument(n >= 0, "Negative skip length.");
264        checkStream();
265
266        if (n == 0) {
267            return 0;
268        } else if (n <= outBuffer.remaining()) {
269            final int pos = outBuffer.position() + (int) n;
270            outBuffer.position(pos);
271            return n;
272        } else {
273            /*
274             * Subtract outBuffer.remaining() to see how many bytes we need to
275             * skip in the underlying stream. Add outBuffer.remaining() to the
276             * actual number of skipped bytes in the underlying stream to get
277             * the number of skipped bytes from the user's point of view.
278             */
279            n -= outBuffer.remaining();
280            long skipped = input.skip(n);
281            if (skipped < 0) {
282                skipped = 0;
283            }
284            final long pos = streamOffset + skipped;
285            skipped += outBuffer.remaining();
286            resetStreamOffset(pos);
287            return skipped;
288        }
289    }
290
291    /**
292     * Overrides the {@link CtrCryptoInputStream#read(ByteBuffer)}. Reads a
293     * sequence of bytes from this channel into the given buffer.
294     *
295     * @param buf The buffer into which bytes are to be transferred.
296     * @return The number of bytes read, possibly zero, or {@code -1} if the
297     *         channel has reached end-of-stream.
298     * @throws IOException if an I/O error occurs.
299     */
300    @Override
301    public int read(final ByteBuffer buf) throws IOException {
302        checkStream();
303        int unread = outBuffer.remaining();
304        if (unread <= 0) { // Fill the unread decrypted data buffer firstly
305            final int n = input.read(inBuffer);
306            if (n <= 0) {
307                return n;
308            }
309
310            streamOffset += n; // Read n bytes
311            if (buf.isDirect() && buf.remaining() >= inBuffer.position()
312                    && padding == 0) {
313                // Use buf as the output buffer directly
314                decryptInPlace(buf);
315                padding = postDecryption(streamOffset);
316                return n;
317            }
318            // Use outBuffer as the output buffer
319            decrypt();
320            padding = postDecryption(streamOffset);
321        }
322
323        // Copy decrypted data from outBuffer to buf
324        unread = outBuffer.remaining();
325        final int toRead = buf.remaining();
326        if (toRead <= unread) {
327            final int limit = outBuffer.limit();
328            outBuffer.limit(outBuffer.position() + toRead);
329            buf.put(outBuffer);
330            outBuffer.limit(limit);
331            return toRead;
332        }
333        buf.put(outBuffer);
334        return unread;
335    }
336
337    /**
338     * Seeks the stream to a specific position relative to start of the under
339     * layer stream.
340     *
341     * @param position the given position in the data.
342     * @throws IOException if an I/O error occurs.
343     */
344    public void seek(final long position) throws IOException {
345        Utils.checkArgument(position >= 0, "Cannot seek to negative offset.");
346        checkStream();
347        /*
348         * If data of target pos in the underlying stream has already been read
349         * and decrypted in outBuffer, we just need to re-position outBuffer.
350         */
351        if (position >= getStreamPosition() && position <= getStreamOffset()) {
352            final int forward = (int) (position - getStreamPosition());
353            if (forward > 0) {
354                outBuffer.position(outBuffer.position() + forward);
355            }
356        } else {
357            input.seek(position);
358            resetStreamOffset(position);
359        }
360    }
361
362    /**
363     * Gets the offset of the stream.
364     *
365     * @return the stream offset.
366     */
367    protected long getStreamOffset() {
368        return streamOffset;
369    }
370
371    /**
372     * Sets the offset of stream.
373     *
374     * @param streamOffset the stream offset.
375     */
376    protected void setStreamOffset(final long streamOffset) {
377        this.streamOffset = streamOffset;
378    }
379
380    /**
381     * Gets the position of the stream.
382     *
383     * @return the position of the stream.
384     */
385    protected long getStreamPosition() {
386        return streamOffset - outBuffer.remaining();
387    }
388
389    /**
390     * Decrypts more data by reading the under layer stream. The decrypted data
391     * will be put in the output buffer.
392     *
393     * @return The number of decrypted data. -1 if end of the decrypted stream.
394     * @throws IOException if an I/O error occurs.
395     */
396    @Override
397    protected int decryptMore() throws IOException {
398        final int n = input.read(inBuffer);
399        if (n <= 0) {
400            return n;
401        }
402
403        streamOffset += n; // Read n bytes
404        decrypt();
405        padding = postDecryption(streamOffset);
406        return outBuffer.remaining();
407    }
408
409    /**
410     * Does the decryption using inBuffer as input and outBuffer as output. Upon
411     * return, inBuffer is cleared; the decrypted data starts at
412     * outBuffer.position() and ends at outBuffer.limit().
413     *
414     * @throws IOException if an I/O error occurs.
415     */
416    @Override
417    protected void decrypt() throws IOException {
418        Utils.checkState(inBuffer.position() >= padding);
419        if (inBuffer.position() == padding) {
420            // There is no real data in inBuffer.
421            return;
422        }
423
424        inBuffer.flip();
425        outBuffer.clear();
426        decryptBuffer(outBuffer);
427        inBuffer.clear();
428        outBuffer.flip();
429
430        if (padding > 0) {
431            /*
432             * The plain text and cipher text have a 1:1 mapping, they start at
433             * the same position.
434             */
435            outBuffer.position(padding);
436        }
437    }
438
439    /**
440     * Does the decryption using inBuffer as input and buf as output. Upon
441     * return, inBuffer is cleared; the buf's position will be equal to
442     * <i>p</i>&nbsp;{@code +}&nbsp;<i>n</i> where <i>p</i> is the position
443     * before decryption, <i>n</i> is the number of bytes decrypted. The buf's
444     * limit will not have changed.
445     *
446     * @param buf The buffer into which bytes are to be transferred.
447     * @throws IOException if an I/O error occurs.
448     */
449    protected void decryptInPlace(final ByteBuffer buf) throws IOException {
450        Utils.checkState(inBuffer.position() >= padding);
451        Utils.checkState(buf.isDirect());
452        Utils.checkState(buf.remaining() >= inBuffer.position());
453        Utils.checkState(padding == 0);
454
455        if (inBuffer.position() == padding) {
456            // There is no real data in inBuffer.
457            return;
458        }
459        inBuffer.flip();
460        decryptBuffer(buf);
461        inBuffer.clear();
462    }
463
464    /**
465     * Decrypts all data in buf: total n bytes from given start position. Output
466     * is also buf and same start position. buf.position() and buf.limit()
467     * should be unchanged after decryption.
468     *
469     * @param buf The buffer into which bytes are to be transferred.
470     * @param offset the start offset in the data.
471     * @param len the maximum number of decrypted data bytes to read.
472     * @throws IOException if an I/O error occurs.
473     */
474    protected void decrypt(final ByteBuffer buf, final int offset, final int len)
475            throws IOException {
476        final int pos = buf.position();
477        final int limit = buf.limit();
478        int n = 0;
479        while (n < len) {
480            buf.position(offset + n);
481            buf.limit(offset + n + Math.min(len - n, inBuffer.remaining()));
482            inBuffer.put(buf);
483            // Do decryption
484            try {
485                decrypt();
486                buf.position(offset + n);
487                buf.limit(limit);
488                n += outBuffer.remaining();
489                buf.put(outBuffer);
490            } finally {
491                padding = postDecryption(streamOffset - (len - n));
492            }
493        }
494        buf.position(pos);
495    }
496
497    /**
498     * This method is executed immediately after decryption. Checks whether
499     * cipher should be updated and recalculate padding if needed.
500     *
501     * @param position the given position in the data..
502     * @return the byte.
503     * @throws IOException if an I/O error occurs.
504     */
505    protected byte postDecryption(final long position) throws IOException {
506        byte padding = 0;
507        if (cipherReset) {
508            /*
509             * This code is generally not executed since the cipher usually
510             * maintains cipher context (e.g. the counter) internally. However,
511             * some implementations can't maintain context so a re-init is
512             * necessary after each decryption call.
513             */
514            resetCipher(position);
515            padding = getPadding(position);
516            inBuffer.position(padding);
517        }
518        return padding;
519    }
520
521    /**
522     * Gets the initialization vector.
523     *
524     * @return the initIV.
525     */
526    protected byte[] getInitIV() {
527        return initIV;
528    }
529
530    /**
531     * Gets the counter for input stream position.
532     *
533     * @param position the given position in the data.
534     * @return the counter for input stream position.
535     */
536    protected long getCounter(final long position) {
537        return position / cipher.getBlockSize();
538    }
539
540    /**
541     * Gets the padding for input stream position.
542     *
543     * @param position the given position in the data.
544     * @return the padding for input stream position.
545     */
546    protected byte getPadding(final long position) {
547        return (byte) (position % cipher.getBlockSize());
548    }
549
550    /**
551     * Overrides the {@link CtrCryptoInputStream#initCipher()}. Initializes the
552     * cipher.
553     */
554    @Override
555    protected void initCipher() {
556        // Do nothing for initCipher
557        // Will reset the cipher when reset the stream offset
558    }
559
560    /**
561     * Calculates the counter and iv, resets the cipher.
562     *
563     * @param position the given position in the data.
564     * @throws IOException if an I/O error occurs.
565     */
566    protected void resetCipher(final long position) throws IOException {
567        final long counter = getCounter(position);
568        CtrCryptoInputStream.calculateIV(initIV, counter, iv);
569        try {
570            cipher.init(Cipher.DECRYPT_MODE, key, new IvParameterSpec(iv));
571        } catch (final GeneralSecurityException e) {
572            throw new IOException(e);
573        }
574        cipherReset = false;
575    }
576
577    /**
578     * Resets the underlying stream offset; clear {@link #inBuffer} and
579     * {@link #outBuffer}. This Typically happens during {@link #skip(long)}.
580     *
581     * @param offset the offset of the stream.
582     * @throws IOException if an I/O error occurs.
583     */
584    protected void resetStreamOffset(final long offset) throws IOException {
585        streamOffset = offset;
586        inBuffer.clear();
587        outBuffer.clear();
588        outBuffer.limit(0);
589        resetCipher(offset);
590        padding = getPadding(offset);
591        inBuffer.position(padding); // Set proper position for input data.
592    }
593
594    /**
595     * Does the decryption using out as output.
596     *
597     * @param out the output ByteBuffer.
598     * @throws IOException if an I/O error occurs.
599     */
600    protected void decryptBuffer(final ByteBuffer out) throws IOException {
601        final int inputSize = inBuffer.remaining();
602        try {
603            final int n = cipher.update(inBuffer, out);
604            if (n < inputSize) {
605                /**
606                 * Typically code will not get here. CryptoCipher#update will
607                 * consume all input data and put result in outBuffer.
608                 * CryptoCipher#doFinal will reset the cipher context.
609                 */
610                cipher.doFinal(inBuffer, out);
611                cipherReset = true;
612            }
613        } catch (final GeneralSecurityException e) {
614            throw new IOException(e);
615        }
616    }
617
618    /**
619     * <p>
620     * This method is only for Counter (CTR) mode. Generally the CryptoCipher
621     * calculates the IV and maintain encryption context internally.For example
622     * a Cipher will maintain its encryption context internally when we do
623     * encryption/decryption using the CryptoCipher#update interface.
624     * </p>
625     * <p>
626     * Encryption/Decryption is not always on the entire file. For example, in
627     * Hadoop, a node may only decrypt a portion of a file (i.e. a split). In
628     * these situations, the counter is derived from the file position.
629     * </p>
630     * The IV can be calculated by combining the initial IV and the counter with
631     * a lossless operation (concatenation, addition, or XOR).
632     *
633     * @see <a
634     *      href="http://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Counter_.28CTR.29">
635     *      http://en.wikipedia.org/wiki/Block_cipher_mode_of_operation#Counter_.28CTR.29</a>
636     *
637     * @param initIV initial IV
638     * @param counter counter for input stream position
639     * @param IV the IV for input stream position
640     */
641    static void calculateIV(final byte[] initIV, long counter, final byte[] IV) {
642        Utils.checkArgument(initIV.length == CryptoCipherFactory.AES_BLOCK_SIZE);
643        Utils.checkArgument(IV.length == CryptoCipherFactory.AES_BLOCK_SIZE);
644
645        int i = IV.length; // IV length
646        int j = 0; // counter bytes index
647        int sum = 0;
648        while (i-- > 0) {
649            // (sum >>> Byte.SIZE) is the carry for addition
650            sum = (initIV[i] & 0xff) + (sum >>> Byte.SIZE); // NOPMD
651            if (j++ < 8) { // Big-endian, and long is 8 bytes length
652                sum += (byte) counter & 0xff;
653                counter >>>= 8;
654            }
655            IV[i] = (byte) sum;
656        }
657    }
658}