From e34e2acd7dfcc04190decb3c05477bc4c997c8f8 Mon Sep 17 00:00:00 2001 From: ryneeverett Date: Tue, 5 Jan 2016 12:35:54 -0500 Subject: [PATCH] Support recursive/nested inputs. --- mkcodes.py | 38 ++++++++++++++++++++------------------ tests/data/nest/deep.md | 3 +++ tests/test.py | 16 +++++++++++++++- 3 files changed, 38 insertions(+), 19 deletions(-) create mode 100644 tests/data/nest/deep.md diff --git a/mkcodes.py b/mkcodes.py index 9602e94..269fb41 100644 --- a/mkcodes.py +++ b/mkcodes.py @@ -14,24 +14,16 @@ from markdown.treeprocessors import Treeprocessor -def iglob(input): - try: - return glob.iglob(input + '/**', recursive=True) - except TypeError: - warnings.warn('In python<3.5, inputs are not recursive.') - return glob.iglob(input + '/*') - - def default_state(): return [], True, False -def github_codeblocks(filename, safe): +def github_codeblocks(filepath, safe): codeblocks = [] codeblock_re = r'^```.*' codeblock_open_re = r'^```(`*)(py|python){0}$'.format('' if safe else '?') - with open(filename, 'r') as f: + with open(filepath, 'r') as f: block = [] python = True in_codeblock = False @@ -55,7 +47,7 @@ def github_codeblocks(filename, safe): return codeblocks -def markdown_codeblocks(filename, safe): +def markdown_codeblocks(filepath, safe): import markdown codeblocks = [] @@ -75,7 +67,7 @@ def extendMarkdown(self, md, md_globals): doctestextension = DoctestExtension() markdowner = markdown.Markdown(extensions=[doctestextension]) - markdowner.convertFile(filename, output=os.devnull) + markdowner.convertFile(filepath, output=os.devnull) return codeblocks @@ -84,12 +76,20 @@ def is_markdown(f): return os.path.splitext(f)[1] in markdown_extensions +def get_nested_files(directory, depth): + for i in glob.iglob(directory + '/*'): + if os.path.isdir(i): + yield from get_nested_files(i, depth+1) + elif is_markdown(i): + yield (i, depth) + + def get_files(inputs): for i in inputs: if os.path.isdir(i): - yield from filter(is_markdown, iglob(i)) + yield from get_nested_files(i, 0) elif is_markdown(i): - yield i + yield (i, 0) @click.command() @@ -103,11 +103,13 @@ def get_files(inputs): def main(inputs, output, github, safe): collect_codeblocks = github_codeblocks if github else markdown_codeblocks - for filename in get_files(inputs): - code = '\n\n'.join(collect_codeblocks(filename, safe)) + for filepath, depth in get_files(inputs): + code = '\n\n'.join(collect_codeblocks(filepath, safe)) + + filename = os.path.splitext(filepath)[0] + outputname = os.sep.join(filename.split(os.sep)[-1-depth:]) - inputname = os.path.splitext(os.path.basename(filename))[0] - outputfilename = output.format(name=inputname) + outputfilename = output.format(name=outputname) outputdir = os.path.dirname(outputfilename) if not os.path.exists(outputdir): diff --git a/tests/data/nest/deep.md b/tests/data/nest/deep.md new file mode 100644 index 0000000..11edf58 --- /dev/null +++ b/tests/data/nest/deep.md @@ -0,0 +1,3 @@ +```py +print('Hello World!') +``` diff --git a/tests/test.py b/tests/test.py index e28bec1..3fd404f 100644 --- a/tests/test.py +++ b/tests/test.py @@ -77,6 +77,19 @@ def test_directory(self): self.call(inputfile='tests/data') self.assertTrue(os.path.exists(self.output)) + def test_directory_recursive(self): + try: + subprocess.call([ + 'mkcodes', '--output', 'tests/{name}.py', '--github', + 'tests/data']) + self.assertTrue(os.path.exists('tests/some.py')) + self.assertTrue(os.path.exists('tests/other.py')) + self.assertTrue(os.path.exists('tests/nest/deep.py')) + finally: + self.remove('tests/some.py') + self.remove('tests/other.py') + shutil.rmtree('tests/nest', ignore_errors=True) + def test_multiple(self): try: subprocess.call([ @@ -91,6 +104,7 @@ def test_multiple(self): self.assertFileEqual('tests/other.py', """\ qux = 4 """) + self.assertFalse(os.path.exists('tests/nest/deep.py')) finally: self.remove('tests/some.py') self.remove('tests/other.py') @@ -107,7 +121,7 @@ def test_unexistant_output_directory(self): backticks = range(5, 7) """) finally: - shutil.rmtree('tests/unexistant') + shutil.rmtree('tests/unexistant', ignore_errors=True) @unittest.skip def test_glob(self):