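# JEB script - demo of the AST API
# Decrypts obfuscated string literals and replaces the decryption calls with string constants
# See the JEB API reference (jeb.api.ast) for details on the AST classes used here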

import sys
import os
import time
from jeb.api import IScript
from jeb.api import EngineOption
from jeb.api.ui import View
from jeb.api.dex import Dex
from jeb.api.ast import Class, Field, Method, Call, Constant, StaticField, NewArray

class ASTDecryptStrings(IScript):
  """Decrypt obfuscated string literals in one class and inline the results.

  The target class is expected to keep its encrypted payload in a
  private static final byte[] field that is assigned an array literal by the
  first statement of the static initializer (<clinit>), and to decrypt strings
  through a method called with three integer constants. Every such call in
  the decompiled AST is replaced by a string constant holding the plaintext.
  """

  # Method name the JVM/Dalvik toolchain gives to a class's static initializer.
  CLINIT = '<clinit>'

  def run(self, jeb):
    """Script entry point invoked by JEB.

    Locates the encrypted byte array and the decryption method in class
    self.csig, then rewrites every decryption call in that class's methods.
    Problems are reported with print and abort the script.
    """
    self.jeb = jeb
    self.dex = self.jeb.getDex()
    self.cstbuilder = Constant.Builder(jeb)

    # NOTE(review): JEB usually identifies classes by their full type
    # signature (e.g. 'Lcom/example/MainActivity;') -- confirm this resolves.
    self.csig = 'MainActivity'
    self.encbytes = []
    self.mname_decrypt = None
    # the encryption keys could be determined by analyzing the decryption method
    self.keys = [409, 62, -8]

    if not jeb.decompileClass(self.csig):
      # Previously formatted with the undefined name `csig` (NameError).
      print('Could not find class "%s"' % self.csig)
      return

    c = jeb.getDecompiledClassTree(self.csig)
    if c is None:
      print('Could not retrieve the AST of class "%s"' % self.csig)
      return

    self._findEncryptedData(c)

    # Each of these is fatal: decrypting without them would produce garbage.
    if not self.encbytes:
      print('Encrypted strings byte array not found')
      return

    if not self.mname_decrypt:
      print('Decryption method was not found')
      return

    for m in c.getMethods():
      print('Decrypting strings in method: %s' % m.getName())
      self.decryptMethodStrings(m)

  def _findEncryptedData(self, c):
    """Fill self.encbytes and self.mname_decrypt from the fields of class c.

    Stops at the first private static final byte[] field whose contents are
    read from <clinit>, so bytes from several fields are never concatenated.
    """
    wanted_flags = Dex.ACC_PRIVATE | Dex.ACC_STATIC | Dex.ACC_FINAL
    for f in c.getFields():
      fsig = f.getSignature()
      if not fsig.endswith(':[B'):
        continue
      fd = self.dex.getFieldData(fsig)
      if (fd.getAccessFlags() & wanted_flags) != wanted_flags:
        continue
      print('Found field: %s' % fsig)

      # Any method other than the static initializer that references the
      # field is taken as the decryption routine (the last one seen wins).
      for mindex in self.dex.getFieldReferences(fd.getFieldIndex()):
        mname = self.dex.getMethod(mindex).getName(False)
        if mname != self.CLINIT:
          self.mname_decrypt = mname

      if self._readArrayInitializer(c, fsig):
        return

  def _readArrayInitializer(self, c, fsig):
    """Append the bytes assigned to field fsig in <clinit> to self.encbytes.

    Only the first statement of the static initializer is inspected; it must
    assign an array literal to the field. Returns True if bytes were read.
    """
    for m in c.getMethods():
      if m.getName() != self.CLINIT:
        continue
      body = m.getBody()
      if body is None or body.size() == 0:
        return False
      s0 = body.get(0)
      # The first statement is not necessarily an assignment.
      if not hasattr(s0, 'getLeft'):
        return False
      left = s0.getLeft()
      if not (isinstance(left, StaticField) and left.getField().getSignature() == fsig):
        return False
      array = s0.getRight()
      if not isinstance(array, NewArray):
        return False
      values = array.getInitialValues()
      if values is None:
        return False
      for v in values:
        self.encbytes.append(v.getByte())
      return len(self.encbytes) > 0
    return False

  def decryptMethodStrings(self, m):
    """Replace decryption calls in the body of method m, statement by statement."""
    block = m.getBody()
    if block is None:
      # Methods without a body have nothing to rewrite.
      return
    i = 0
    # size() is re-read on every iteration in case the block changes.
    while i < block.size():
      stm = block.get(i)
      self.checkElement(block, stm)
      i += 1

  def checkElement(self, parent, e):
    """Recursively replace decryption calls under e with string constants.

    A call qualifies when it targets self.mname_decrypt and has exactly three
    arguments, all constants. The call is replaced inside parent. Nested
    class, field and method definitions are not descended into.
    """
    if isinstance(e, Call) and e.getMethod().getName() == self.mname_decrypt:
      args = list(e.getArguments())
      if len(args) == 3 and all(isinstance(arg, Constant) for arg in args):
        v = [arg.getInt() for arg in args]
        decrypted_string = self.decrypt(v[0], v[1], v[2])
        parent.replaceSubElement(e, self.cstbuilder.buildString(decrypted_string))
        print('  Decrypted string: %s' % repr(decrypted_string))
        # The replaced call only had constant arguments; nothing left below it.
        return

    for subelt in e.getSubElements():
      if isinstance(subelt, (Class, Field, Method)):
        continue
      self.checkElement(e, subelt)

  def decrypt(self, length, curChar, pos):
    """Decrypt one string from self.encbytes.

    length and curChar are first offset by self.keys[0] and self.keys[1].
    Each following character is the previous one plus the encrypted byte at
    index pos (then pos+1, ...) plus self.keys[2], truncated to 8 bits.
    Decoding stops early if pos leaves the bounds of the byte array.
    """
    length += self.keys[0]
    curChar += self.keys[1]
    chars = []
    for _ in range(length):
      chars.append(chr(curChar & 0xFF))
      # A negative index would silently wrap around in Python.
      if pos < 0 or pos >= len(self.encbytes):
        break
      curChar += self.encbytes[pos] + self.keys[2]
      pos += 1
    return ''.join(chars)