From 8b63269162f036241a0af7cd1b1f05ab8fedf9db Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 21:04:30 +0800 Subject: [PATCH 01/14] =?UTF-8?q?feat:=20=E6=B7=BB=E5=8A=A0=E5=A4=96?= =?UTF-8?q?=E9=83=A8=E7=AE=97=E5=AD=90=E6=94=AF=E6=8C=81=E4=BB=A5=E8=B7=B3?= =?UTF-8?q?=E8=BF=87=E5=A4=8D=E6=9D=82=E7=AC=A6=E5=8F=B7=E5=B1=95=E5=BC=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- design.md | 31 ++++++++++++ manual.md | 81 +++++++++++++++++++++++++++++- sympy_codegen.py | 125 +++++++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 233 insertions(+), 4 deletions(-) diff --git a/design.md b/design.md index 2701435..cfc3b4e 100644 --- a/design.md +++ b/design.md @@ -45,3 +45,34 @@ Implement the following methods in the `Element` class: ### 5.2 Standalone Mathematical Script Use `--task custom` with a script containing `get_model()` to generate arbitrary kernels. + +## 6. External Operators + +### 6.1 Motivation +For certain computations (e.g., 12×12 matrix inversion), letting SymPy symbolically expand the full expression is prohibitively slow and generates enormous code. External operators allow these computations to be **skipped** during symbolic expansion and instead emitted as **function calls** in the generated code, assuming an external implementation exists. + +### 6.2 Architecture +The external operator feature has three layers: + +1. **`ExternalOperator`** — Describes the signature of an externally implemented operator: + - `name`: Operator identifier (e.g., `"inv12"`) + - `n_inputs` / `n_outputs`: Number of scalar input/output elements + - `cpp_func`, `fortran_func`, `jax_func`: Function names for each backend + +2. **`ExternalCall`** — Records a specific invocation of an external operator within a `MathModel`, binding concrete SymPy input expressions to output placeholder symbols. + +3. **`external_call()`** — Helper function that creates placeholder output symbols and registers the call in the model. + +All external operators follow a unified **vector-in / vector-out** interface: +- C++: `void func(const double* in, double* out)` +- Fortran: `subroutine func(in_vec, out_vec)` + +### 6.3 Code Generation Order +When external calls are present, the generated code structure becomes: +1. Unpack input variables +2. **External operator calls** (declare input/output arrays, assign inputs, call function, unpack outputs) +3. Normal CSE chunks +4. Output assignment + +### 6.4 CSE Interaction +External operator outputs are plain SymPy symbols (placeholders), so CSE does not attempt to expand them. Input expressions are printed directly (not included in CSE), keeping the first implementation simple and reliable. diff --git a/manual.md b/manual.md index b23867d..a3c2fa6 100644 --- a/manual.md +++ b/manual.md @@ -38,7 +38,86 @@ For complex elements, a single monolithic kernel can be slow to compile and hard 3. **Assembly**: Integration point contribution ($B^T D B \det(J) W$). 4. **Lumped Mass**: Element-level mass distribution for explicit dynamics. -### 3.2 Fast Validation Solvers (JAX) +### 3.2 External Operators + +External operators allow you to skip symbolic expansion of expensive computations (e.g., large matrix inversion) and instead emit function calls to externally implemented routines. + +#### 3.2.1 Define an External Operator + +```python +from sympy_codegen import MathModel, ExternalOperator, external_call + +ext_ops = { + "inv12": ExternalOperator( + name="inv12", + n_inputs=144, # 12×12 = 144 elements + n_outputs=144, + cpp_func="fea_inv12", + fortran_func="fea_inv12", + jax_func=None, # Not supported in JAX + ) +} +``` + +#### 3.2.2 Use in a MathModel + +```python +import sympy as sp + +A_syms = list(sp.symbols("A_0:144", real=True)) + +model = MathModel( + inputs=A_syms, + outputs=[], + name="kernel_with_inv", + external_ops=ext_ops, +) + +# Register the external call — returns 144 placeholder symbols +invA = external_call(model, "inv12", A_syms) + +# Use the output symbols in subsequent expressions +B_syms = list(sp.symbols("B_0:144", real=True)) +model.inputs = A_syms + B_syms + +outputs = [] +for i in range(12): + for j in range(12): + val = sum(invA[i*12 + k] * B_syms[k*12 + j] for k in range(12)) + outputs.append(val) +model.outputs = outputs +``` + +#### 3.2.3 Multiple Calls with Different Prefixes + +```python +invA = external_call(model, "inv12", A_syms, prefix="invA") +invB = external_call(model, "inv12", B_syms, prefix="invB") +``` + +#### 3.2.4 Generated Code Example (C++) + +```cpp +// --- External Operator: inv12 --- +double inv12_in[144]; +double inv12_out[144]; +inv12_in[0] = A_0; +// ... +inv12_in[143] = A_143; +fea_inv12(inv12_in, inv12_out); +double inv12_0 = inv12_out[0]; +// ... +double inv12_143 = inv12_out[143]; + +// --- Chunk 0 (normal CSE) --- +out[0] = ...; +``` + +#### 3.2.5 JAX Limitation + +If `jax_func` is `None`, generating JAX code for a model that uses that operator will raise a `ValueError`. + +### 3.3 Fast Validation Solvers (JAX) Two scripts are provided for rapid verification of generated JAX kernels: - `static.py`: Solves linear static problems using implicit integration. - `explicit.py`: Solves dynamic problems using the Central Difference method (explicit). diff --git a/sympy_codegen.py b/sympy_codegen.py index 45dbaa5..41af213 100644 --- a/sympy_codegen.py +++ b/sympy_codegen.py @@ -24,6 +24,26 @@ # --------------------------------------------------------------------------- # 辅助数据结构:用于 CSE 结果缓存和跨后端共享 # --------------------------------------------------------------------------- +class ExternalOperator: + """描述一个外部实现的算子(向量输入、向量输出)。""" + def __init__(self, name, n_inputs, n_outputs, cpp_func=None, fortran_func=None, jax_func=None): + self.name = name # 算子名称,如 "inv12" + self.n_inputs = n_inputs # 输入元素个数,如 144 + self.n_outputs = n_outputs # 输出元素个数,如 144 + self.cpp_func = cpp_func or name # C++ 函数名 + self.fortran_func = fortran_func or name # Fortran 子程序名 + self.jax_func = jax_func # JAX 函数名,None 表示不支持 + + +class ExternalCall: + """记录 MathModel 中一次外部算子调用。""" + def __init__(self, op_name, input_exprs, output_symbols, prefix=None): + self.op_name = op_name # 对应 ExternalOperator 的 name + self.input_exprs = list(input_exprs) # SymPy 表达式列表(输入) + self.output_symbols = list(output_symbols) # SymPy 符号列表(输出占位) + self.prefix = prefix or op_name # 变量名前缀,避免多次调用冲突 + + class LoweredChunk: """单个 chunk 的 CSE 结果""" def __init__(self, chunk_index: int, start_index: int, sub_exprs: list, simplified_outputs: list): @@ -35,10 +55,13 @@ def __init__(self, chunk_index: int, start_index: int, sub_exprs: list, simplifi class LoweredModel: """整个模型经过 CSE lowering 后的结果""" - def __init__(self, model_name: str, chunk_size: int, chunks: list): + def __init__(self, model_name: str, chunk_size: int, chunks: list, + external_calls=None, external_ops=None): self.model_name = model_name self.chunk_size = chunk_size self.chunks = chunks # list of LoweredChunk + self.external_calls = external_calls or [] # list[ExternalCall] + self.external_ops = external_ops or {} # dict[str, ExternalOperator] class CachedPrinter: @@ -191,13 +214,48 @@ def _print_Pow(self, expr): class MathModel: """数据容器:存储数学定义""" - def __init__(self, inputs, outputs, name="kernel", input_names=None, output_names=None, is_operator=False): + def __init__(self, inputs, outputs, name="kernel", input_names=None, output_names=None, + is_operator=False, external_ops=None, external_calls=None): self.inputs = inputs # SymPy 符号列表 self.outputs = outputs # SymPy 表达式列表 self.name = name self.input_names = input_names or [str(s) for s in inputs] self.output_names = output_names or [f"out[{i}]" for i in range(len(outputs))] self.is_operator = is_operator # 是否作为算子生成(可能包含SIMD优化等) + self.external_ops = external_ops or {} # dict[str, ExternalOperator] + self.external_calls = external_calls or [] # list[ExternalCall] + + +def external_call(model, op_name, input_exprs, n_outputs=None, prefix=None): + """ + 在 MathModel 中注册一次外部算子调用,返回输出符号列表。 + + Args: + model: MathModel 实例 + op_name: model.external_ops 中的键名 + input_exprs: 输入表达式列表 (SymPy 表达式) + n_outputs: 输出数量(如不提供则从 external_ops 查找) + prefix: 输出符号前缀(默认等于 op_name,多次调用同一算子时需不同前缀) + + Returns: + list[sp.Symbol]: 输出占位符号列表,可用于后续表达式 + """ + op = model.external_ops[op_name] + n_outputs = n_outputs or op.n_outputs + prefix = prefix or op_name + + output_symbols = list(sp.symbols(f"{prefix}_0:{n_outputs}")) + + model.external_calls.append( + ExternalCall( + op_name=op_name, + input_exprs=list(input_exprs), + output_symbols=output_symbols, + prefix=prefix, + ) + ) + + return output_symbols class FEACompiler: @@ -226,7 +284,9 @@ def lower_model(model: MathModel, chunk_size: int) -> LoweredModel: ) ) - return LoweredModel(model.name, chunk_size, chunks) + return LoweredModel(model.name, chunk_size, chunks, + external_calls=model.external_calls, + external_ops=model.external_ops) # ========================================================================= # Chunk Size 策略:根据模型规模和目标平台决定 chunk size @@ -463,6 +523,23 @@ def _to_jax(model, lowered=None, chunk_size=None, cse_strategy="auto"): printer = CachedPrinter(JaxPrinter()) all_simplified_outputs = [] + + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + if op.jax_func is None: + raise ValueError( + f"External operator '{call.op_name}' has no JAX implementation. " + f"Cannot generate JAX code for model '{model.name}'." + ) + lines.append(f" # --- External Operator: {call.op_name} ---") + in_parts = ", ".join(printer.doprint(e) for e in call.input_exprs) + lines.append(f" {call.prefix}_in = jnp.array([{in_parts}])") + lines.append(f" {call.prefix}_out = {op.jax_func}({call.prefix}_in)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out[{i}]") + lines.append("") # 使用 lowered 结果 for chunk in lowered.chunks: @@ -535,6 +612,24 @@ def _to_source(model, is_cuda=False, lowered=None, chunk_size=None, cse_strategy # 初始化带缓存的 Printer printer = CachedPrinter(FEACodePrinter()) + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + body_lines.append(f" // --- External Operator: {call.op_name} ---") + body_lines.append(f" double {call.prefix}_in[{op.n_inputs}];") + body_lines.append(f" double {call.prefix}_out[{op.n_outputs}];") + + for i, expr in enumerate(call.input_exprs): + body_lines.append(f" {call.prefix}_in[{i}] = {printer.doprint(expr)};") + + body_lines.append(f" {op.cpp_func}({call.prefix}_in, {call.prefix}_out);") + + for i, sym in enumerate(call.output_symbols): + body_lines.append(f" double {sym} = {call.prefix}_out[{i}];") + + body_lines.append("") + # 使用 lowered 结果 for chunk in lowered.chunks: body_lines.append(f"\n // --- Chunk {chunk.chunk_index} ---") @@ -656,6 +751,30 @@ def _fortran_declare(type_decl, vars_list, indent=" "): if s.isidentifier(): lines.append(f" {s} = in_vec({i + 1})") + # 外部算子调用 + if lowered.external_calls: + lines.append("") + lines.append(" ! --- External Operator Calls ---") + # 先声明所有外部算子相关的变量 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.extend(_fortran_declare("double precision", + [f"{call.prefix}_in({op.n_inputs})", f"{call.prefix}_out({op.n_outputs})"], " ")) + out_var_names = [str(sym) for sym in call.output_symbols] + lines.extend(_fortran_declare("double precision", out_var_names, " ")) + + # 然后赋值和调用 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.append(f" ! External Operator: {call.op_name}") + for i, expr in enumerate(call.input_exprs): + lines.append(f" {call.prefix}_in({i + 1}) = {printer.doprint(expr)}") + lines.append(f" call {op.fortran_func}({call.prefix}_in, {call.prefix}_out)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out({i + 1})") + + lines.append("") + lines.append(" ! --- Local Variables for CSE ---") # 使用 lowered 结果 From 3c6afdd5fc16d126ad7cf81986fe30a852749650 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 22:16:42 +0800 Subject: [PATCH 02/14] =?UTF-8?q?refactor:=E6=8B=86=E5=88=86sympy=5Fcodege?= =?UTF-8?q?n.py?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/__init__.py | 22 + codegen/cli.py | 202 +++++++++ codegen/compiler.py | 594 ++++++++++++++++++++++++ codegen/loader.py | 24 + codegen/lowered.py | 52 +++ codegen/model.py | 49 ++ codegen/printer.py | 133 ++++++ sympy_codegen.py | 1059 +------------------------------------------ 8 files changed, 1082 insertions(+), 1053 deletions(-) create mode 100644 codegen/__init__.py create mode 100644 codegen/cli.py create mode 100644 codegen/compiler.py create mode 100644 codegen/loader.py create mode 100644 codegen/lowered.py create mode 100644 codegen/model.py create mode 100644 codegen/printer.py diff --git a/codegen/__init__.py b/codegen/__init__.py new file mode 100644 index 0000000..810929d --- /dev/null +++ b/codegen/__init__.py @@ -0,0 +1,22 @@ +"""codegen 包 — 对外 re-export 所有公开符号,保证 from sympy_codegen import ... 仍然可用。""" + +from codegen.model import ExternalOperator, ExternalCall, MathModel, external_call +from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter +from codegen.printer import FEACodePrinter, FEAFortranPrinter +from codegen.compiler import FEACompiler +from codegen.loader import load_element, load_material + +__all__ = [ + "ExternalOperator", + "ExternalCall", + "MathModel", + "external_call", + "LoweredChunk", + "LoweredModel", + "CachedPrinter", + "FEACodePrinter", + "FEAFortranPrinter", + "FEACompiler", + "load_element", + "load_material", +] diff --git a/codegen/cli.py b/codegen/cli.py new file mode 100644 index 0000000..02564d3 --- /dev/null +++ b/codegen/cli.py @@ -0,0 +1,202 @@ +import argparse +import importlib.util +import json as _json +import os +import sys +from pathlib import Path + +from codegen.compiler import FEACompiler +from codegen.loader import load_element, load_material + + +def _default_output(model_name: str, target: str) -> str: + t = target.lower() + if t == "jax": + return f"{model_name}_gen.py" + if t in ("cpp", "c++"): + return f"{model_name}_gen.cpp" + if t == "cuda": + return f"{model_name}_gen.cu" + if t == "fortran": + return f"{model_name}_gen.f90" + return f"{model_name}_{t}.txt" + +def main(): + parser = argparse.ArgumentParser( + description="SymPy FEA 代码生成器 (混合解耦架构)" + ) + parser.add_argument( + "--task", + required=True, + choices=["constitutive", "stiffness", "mass", "custom"], + help="生成任务: 'constitutive' (材料D矩阵), 'stiffness' (单元Ke矩阵), 'mass' (质量矩阵), 或 'custom' (自定义数学模型)", + ) + parser.add_argument( + "--element", "-e", + help="单元名称 (e.g., 'tet4'), required for --task=stiffness", + ) + parser.add_argument( + "--material", "-m", + help="材料名称 (e.g., 'isotropic'), required for --task=constitutive", + ) + parser.add_argument( + "--script", "-s", + help="Python 脚本路径 (用于 --task=custom). 脚本中需要提供 get_model() 函数返回 MathModel.", + ) + parser.add_argument( + "--target", "-t", + required=True, + choices=["jax", "cpp", "cuda", "fortran", "all"], + help="目标语言:jax / cpp / cuda / fortran / all", + ) + parser.add_argument( + "--output", "-o", + default=None, + help="输出文件路径(默认根据任务和名称生成)", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=None, + help="CSE chunk size. 如果省略,则使用 cse-strategy 决定。", + ) + parser.add_argument( + "--cse-strategy", + choices=["auto", "fixed"], + default="auto", + help="CSE chunk sizing 策略。'auto' 根据输出规模自动调整,'fixed' 使用固定默认值。", + ) + parser.add_argument( + "--test", + action="store_true", + default=False, + help="同时生成 CI 测试资产(C++/Fortran wrapper、test_driver.py、build 脚本)", + ) + parser.add_argument( + "--test-output-dir", + default=None, + help="测试资产输出目录(仅在 --test 启用时有效,默认与 --output 相同)", + ) + args = parser.parse_args() + + if args.task == "constitutive": + if not args.material: + parser.error("--material is required for --task=constitutive") + material = load_material(args.material) + model = material.get_constitutive_model() + models_to_compile = {model.name: model} + + elif args.task == "stiffness": + if not args.element: + parser.error("--element is required for --task=stiffness") + element = load_element(args.element) + operators = element.get_stiffness_operators() + if operators: + models_to_compile = {op.name: op for op in operators} + else: + m = element.get_stiffness_model() + models_to_compile = {m.name: m} + + elif args.task == "mass": + if not args.element: + parser.error("--element is required for --task=mass") + element = load_element(args.element) + operators = element.get_mass_operators() + if operators: + models_to_compile = {op.name: op for op in operators} + else: + # Fallback if no specific mass model is defined, though mass usually has operators + parser.error(f"No mass operators defined for element: {args.element}") + + elif args.task == "custom": + if not args.script: + parser.error("--script is required for --task=custom") + script_path = Path(args.script) + if not script_path.exists(): + parser.error(f"Script file not found: {script_path}") + + # Dynamically load the script + spec = importlib.util.spec_from_file_location("custom_script", str(script_path)) + custom_mod = importlib.util.module_from_spec(spec) + sys.modules["custom_script"] = custom_mod + spec.loader.exec_module(custom_mod) + + if not hasattr(custom_mod, "get_model"): + parser.error(f"Script {script_path} must define a 'get_model()' function.") + + models = custom_mod.get_model() + if type(models).__name__ == "MathModel": + models_to_compile = {models.name: models} + elif isinstance(models, list) and all(type(m).__name__ == "MathModel" for m in models): + models_to_compile = {m.name: m for m in models} + else: + parser.error(f"get_model() must return a MathModel or a list of MathModels. Got: {type(models)}") + + # ---------------- Compile Models ---------------- + base_test_dir = Path(args.test_output_dir or args.output or ".") if args.test else None + + for name, model in models_to_compile.items(): + if args.target == "all": + # --target all: 使用 compile_all 实现真正的共享 CSE + generated = FEACompiler.compile_all( + model, + chunk_size=args.chunk_size, + cse_strategy=args.cse_strategy, + test=args.test, + task=args.task, + model_name=args.material or args.element or name, + ) + for t, code in generated.items(): + if t in ("jax", "cpp", "cuda", "fortran"): + out_path = Path(args.output or ".") / _default_output(name, t) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + f.write(code) + print(f"Generated: {out_path}") + # Also copy kernel source to test directory if --test is enabled + if args.test and base_test_dir is not None: + kernel_dir = base_test_dir / name + kernel_dir.mkdir(parents=True, exist_ok=True) + kernel_ext_map = {"cpp": "kernel.cpp", "cuda": "kernel.cu", "fortran": "kernel.f90"} + if t in kernel_ext_map: + kernel_path = kernel_dir / kernel_ext_map[t] + with open(kernel_path, "w", encoding="utf-8") as f: + f.write(code) + elif args.test and t in ("cpp_wrapper", "f90_wrapper", "test_driver", "build_sh", "build_bat"): + # Each model gets its own subdirectory to avoid overwriting + test_dir = base_test_dir / name + test_dir.mkdir(parents=True, exist_ok=True) + fname_map = { + "cpp_wrapper": "main.cpp", + "f90_wrapper": "main.f90", + "test_driver": "test_driver.py", + "build_sh": "build.sh", + "build_bat": "build.bat", + } + out_path = test_dir / fname_map[t] + with open(out_path, "w", encoding="utf-8") as f: + f.write(code) + print(f"Generated: {out_path}") + # Generate codegen_meta.json for test_driver.py to locate sympy_codegen + if args.test and base_test_dir is not None: + test_dir = base_test_dir / name + test_dir.mkdir(parents=True, exist_ok=True) + code_gen_dir = Path(__file__).resolve().parent.parent # codegen/ -> fea_codegen/ + rel_path = os.path.relpath(code_gen_dir, test_dir.resolve()) + meta = {"code_gen_rel_path": rel_path} + meta_path = test_dir / "codegen_meta.json" + with open(meta_path, "w", encoding="utf-8") as f: + _json.dump(meta, f, indent=2) + print(f"Generated: {meta_path}") + else: + # 单一目标编译 + code = FEACompiler.compile( + model, + args.target, + chunk_size=args.chunk_size, + cse_strategy=args.cse_strategy, + ) + out_path = Path(args.output or ".") / _default_output(name, args.target) + with open(out_path, "w", encoding="utf-8") as f: + f.write(code) + print(f"Generated: {out_path}") diff --git a/codegen/compiler.py b/codegen/compiler.py new file mode 100644 index 0000000..39be9fd --- /dev/null +++ b/codegen/compiler.py @@ -0,0 +1,594 @@ +import re + +import sympy as sp +from sympy.core.relational import Relational +from sympy.printing.numpy import JaxPrinter + +from codegen.model import MathModel +from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter +from codegen.printer import FEACodePrinter, FEAFortranPrinter +from definitions.abc import Element + + +class FEACompiler: + # ========================================================================= + # 公共 Lower 阶段:将 MathModel 转换为 LoweredModel,执行 CSE + # ========================================================================= + @staticmethod + def lower_model(model: MathModel, chunk_size: int) -> LoweredModel: + """执行 CSE lowering,返回可被多个后端共享的 LoweredModel""" + outputs = model.outputs + chunks = [] + + for start in range(0, len(outputs), chunk_size): + chunk_index = start // chunk_size + chunk = outputs[start:start + chunk_size] + sub_exprs, simplified_chunk = sp.cse( + chunk, + symbols=sp.numbered_symbols(f"v_{chunk_index}_") + ) + chunks.append( + LoweredChunk( + chunk_index=chunk_index, + start_index=start, + sub_exprs=sub_exprs, + simplified_outputs=simplified_chunk + ) + ) + + return LoweredModel(model.name, chunk_size, chunks, + external_calls=model.external_calls, + external_ops=model.external_ops) + + # ========================================================================= + # Chunk Size 策略:根据模型规模和目标平台决定 chunk size + # ========================================================================= + @staticmethod + def resolve_chunk_size(model: MathModel, target: str, user_chunk_size=None, strategy="auto") -> int: + """ + 决定 CSE chunk size 的策略。 + + Args: + model: 数学模型 + target: 目标平台 (jax/cpp/cuda/fortran等) + user_chunk_size: 用户通过 CLI 指定的 chunk size + strategy: 策略模式 ("auto" 或 "fixed") + + Returns: + 最终的 chunk size + """ + if user_chunk_size is not None: + return user_chunk_size + + nout = len(model.outputs) + target = target.lower() + + # fixed 模式:使用各后端的固定默认值 + if strategy == "fixed": + if target == "jax": + return 50 + if target in ("cpp", "c++", "cuda", "fortran"): + return 24 + return 24 + + # auto 模式:根据输出规模自动调整 + if strategy == "auto": + if target == "jax": + if nout <= 64: + return 64 + elif nout <= 256: + return 48 + else: + return 32 + + # cpp/cuda/fortran 的自适应策略 + if nout <= 32: + return 32 + elif nout <= 128: + return 24 + elif nout <= 512: + return 16 + else: + return 8 + + raise ValueError(f"Unknown strategy: {strategy}") + + # ========================================================================= + # C++/CUDA 兼容性宏:跨平台支持 GCC/Clang/MSVC/CUDA + # ========================================================================= + @staticmethod + def _cpp_cuda_compat_macros() -> str: + """返回统一的 C++/CUDA 跨平台兼容性宏定义""" + return r""" +#if defined(__CUDACC__) + #define FEA_DEVICE __device__ + #define FEA_HOST __host__ + #define FEA_HOST_DEVICE __host__ __device__ + #define FEA_RESTRICT __restrict__ +#else + #define FEA_DEVICE + #define FEA_HOST + #define FEA_HOST_DEVICE + #if defined(_WIN32) || defined(_WIN64) + #if defined(_MSC_VER) + #define FEA_RESTRICT __restrict + #else + #define FEA_RESTRICT __restrict__ + #endif + #else + #if defined(__GNUC__) || defined(__clang__) + #define FEA_RESTRICT __restrict__ + #else + #define FEA_RESTRICT + #endif + #endif +#endif + +#if defined(_MSC_VER) + #define FEA_ALWAYS_INLINE __forceinline +#elif defined(__GNUC__) || defined(__clang__) + #define FEA_ALWAYS_INLINE inline __attribute__((always_inline)) +#else + #define FEA_ALWAYS_INLINE inline +#endif +""" + + # ========================================================================= + # 核心编译接口 + # ========================================================================= + @staticmethod + def compile(model: MathModel, target: str, chunk_size=None, cse_strategy="auto", lowered=None): + """ + 核心分发器:输入 MathModel + target,输出 cpp/cuda/jax/fortran 源码字符串。 + + Args: + model: 数学模型 + target: 目标平台 ('jax', 'cpp', 'cuda', 'fortran') + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + lowered: 预先 lowered 的结果 (可选,用于多后端共享) + """ + target = target.lower() + if target == 'jax': + return FEACompiler._to_jax(model, lowered=lowered, chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target in ['cpp', 'c++']: + return FEACompiler._to_source(model, is_cuda=False, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'cuda': + return FEACompiler._to_source(model, is_cuda=True, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'fortran': + return FEACompiler._to_fortran(model, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + else: + raise ValueError(f"Unknown target: {target}") + + @staticmethod + def compile_all(model: MathModel, chunk_size=None, cse_strategy="auto", test=False, + task=None, model_name=None): + """ + 一次性生成 jax/cpp/cuda/fortran 四种目标源码。 + + 统一管理 lower 行为: + - 如果所有 target 使用相同的 chunk size,共享一份 lowered + - 如果 JAX 和 cpp/cuda/fortran 使用不同的 chunk size,分别生成 jax_lowered 和 shared_lowered + + Args: + model: 数学模型 + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + test: 是否同时生成测试资产(wrapper、test_driver、build script) + task: CLI 任务类型 ('constitutive', 'stiffness', 'mass', 'custom'),用于 test_driver 重新加载模型 + model_name: 模型/材料/单元名称,用于 test_driver 重新加载模型 + + Returns: + dict: {'jax': code, 'cpp': code, 'cuda': code, 'fortran': code, + 'cpp_wrapper': str, 'f90_wrapper': str, 'test_driver': str, + 'build_sh': str, 'build_bat': str} (后5项仅在 test=True 时存在) + """ + from ci_test.wrappers import generate_cpp_main, generate_f90_main + from ci_test.test_driver_template import generate_test_driver + from ci_test.build_script_generator import generate_build_sh, generate_build_bat + + # 决定各 target 的 chunk size + cpp_chunk = FEACompiler.resolve_chunk_size(model, "cpp", chunk_size, cse_strategy) + jax_chunk = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) + + # 生成 shared lowered 给 cpp/cuda/fortran + shared_lowered = FEACompiler.lower_model(model, cpp_chunk) + + # 决定 JAX 是否共享 lowered + if jax_chunk == cpp_chunk: + jax_lowered = shared_lowered + else: + jax_lowered = FEACompiler.lower_model(model, jax_chunk) + + result = { + "jax": FEACompiler._to_jax(model, lowered=jax_lowered, chunk_size=jax_chunk, cse_strategy=cse_strategy), + "cpp": FEACompiler._to_source(model, is_cuda=False, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + "cuda": FEACompiler._to_source(model, is_cuda=True, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + "fortran": FEACompiler._to_fortran(model, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + } + + if test: + result["cpp_wrapper"] = generate_cpp_main(model) + result["f90_wrapper"] = generate_f90_main(model) + result["test_driver"] = generate_test_driver(model, task=task, model_name=model_name) + result["build_sh"] = generate_build_sh(model) + result["build_bat"] = generate_build_bat(model) + + return result + + @staticmethod + def _to_jax(model, lowered=None, chunk_size=None, cse_strategy="auto"): + """ + 生成 JAX 源码(.py),采用分块 CSE 优化。 + + Args: + model: 数学模型 + lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + """ + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) + lowered = FEACompiler.lower_model(model, chunk_size) + + lines = [ + '"""Generated by sympy_codegen.py. Do not edit."""', + "import jax.numpy as jnp", + "", + "", + f"def compute_{model.name}(in_flat):", + f' """', + f' Compute the {model.name} kernel.', + f' ', + f' Args:', + f' in_flat: Flattened input array, size {len(model.inputs)}', + f' ', + f' Returns:', + f' Flattened output array, size {len(model.outputs)}', + f' ', + f' Input layout:', + ] + + # 添加输入信息 + for i, name in enumerate(model.input_names): + lines.append(f" ' - in_flat[{i}]: {name}") + + lines.append(f" '") + lines.append(f" ' Output layout:") + + # 添加输出信息 + for i, name in enumerate(model.output_names): + lines.append(f" ' - out[{i}]: {name}") + + lines.append(f' """') + + # Unpack inputs IF they are valid identifiers (e.g. xi, c0) + # If they are like "in[0]", we'll handle them via string replacement later + for i, sym in enumerate(model.inputs): + s = str(sym) + is_ident = s.isidentifier() + # print(f"DEBUG: sym={s}, is_ident={is_ident}") + if is_ident: + lines.append(f" {s} = in_flat[{i}]") + + lines.append("") + + printer = CachedPrinter(JaxPrinter()) + all_simplified_outputs = [] + + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + if op.jax_func is None: + raise ValueError( + f"External operator '{call.op_name}' has no JAX implementation. " + f"Cannot generate JAX code for model '{model.name}'." + ) + lines.append(f" # --- External Operator: {call.op_name} ---") + in_parts = ", ".join(printer.doprint(e) for e in call.input_exprs) + lines.append(f" {call.prefix}_in = jnp.array([{in_parts}])") + lines.append(f" {call.prefix}_out = {op.jax_func}({call.prefix}_in)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out[{i}]") + lines.append("") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + for var, expr in chunk.sub_exprs: + lines.append(f" {var} = {printer.doprint(expr)}") + + all_simplified_outputs.extend(chunk.simplified_outputs) + + lines.append("") + lines.append(" # --- Output ---") + out_parts = [printer.doprint(e) for e in all_simplified_outputs] + lines.append(f" return ({','.join(out_parts)})") + + src = "\n".join(lines) + # Final cleanup for JAX and handle C-style inputs + src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") + return src + + @staticmethod + def _to_source(model, is_cuda=False, lowered=None, chunk_size=None, cse_strategy="auto"): + """ + 生成 C++/CUDA 源码,采用分块 CSE 优化及算子化增强。 + + Args: + model: 数学模型 + is_cuda: 是否为 CUDA 目标 + lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + """ + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = FEACompiler.resolve_chunk_size(model, "cuda" if is_cuda else "cpp", + chunk_size, cse_strategy) + lowered = FEACompiler.lower_model(model, chunk_size) + + # --- Generate Comments --- + comment_lines = ["/**"] + comment_lines.append(f" * @brief Computes the {model.name} kernel.") + if model.is_operator: + comment_lines.append(" * @note This is an optimized operator kernel.") + comment_lines.append(" * ") + comment_lines.append(" * @param in Input array (const double*). Layout:") + + for i, name in enumerate(model.input_names): + comment_lines.append(f" * - in[{i}]: {name}") + + comment_lines.append(" * ") + comment_lines.append(" * @param out Output array (double*). Layout:") + + # 列出每个输出的详细信息 + for i, name in enumerate(model.output_names): + comment_lines.append(f" * - out[{i}]: {name}") + + comment_lines.append(" */") + comment_block = "\n".join(comment_lines) + + # --- Generate Function Body --- + body_lines = [] + + # 解包输入变量 + for i, sym in enumerate(model.inputs): + s = str(sym) + # 检查是否是合法标识符(如 coord_2_3),如果是则解包 + if s.isidentifier(): + body_lines.append(f" double {s} = in[{i}];") + + body_lines.append("") + + # 初始化带缓存的 Printer + printer = CachedPrinter(FEACodePrinter()) + + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + body_lines.append(f" // --- External Operator: {call.op_name} ---") + body_lines.append(f" double {call.prefix}_in[{op.n_inputs}];") + body_lines.append(f" double {call.prefix}_out[{op.n_outputs}];") + + for i, expr in enumerate(call.input_exprs): + body_lines.append(f" {call.prefix}_in[{i}] = {printer.doprint(expr)};") + + body_lines.append(f" {op.cpp_func}({call.prefix}_in, {call.prefix}_out);") + + for i, sym in enumerate(call.output_symbols): + body_lines.append(f" double {sym} = {call.prefix}_out[{i}];") + + body_lines.append("") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + body_lines.append(f"\n // --- Chunk {chunk.chunk_index} ---") + + for var, expr in chunk.sub_exprs: + body_lines.append(f" double {var} = {printer.doprint(expr)};") + + for j, out_expr in enumerate(chunk.simplified_outputs): + body_lines.append(f" out[{chunk.start_index + j}] = {printer.doprint(out_expr)};") + + body = "\n".join(body_lines) + + # 统一使用兼容宏体系 + prefix = FEACompiler._cpp_cuda_compat_macros() + "\n" + + if is_cuda: + # CUDA 使用 FEA_DEVICE 宏 + func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + else: + # C++ 使用 FEA_ALWAYS_INLINE 宏 + func_type = "FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + + return f"{prefix}{comment_block}\n{signature} {{ \n{body}\n}}" + + + + @staticmethod + def _to_fortran(model, lowered=None, chunk_size=None, cse_strategy="auto"): + """生成 Fortran 源码,支持分块 CSE 优化。声明和赋值必须分离。""" + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = FEACompiler.resolve_chunk_size(model, "fortran", chunk_size, cse_strategy) + lowered = FEACompiler.lower_model(model, chunk_size) + + printer = CachedPrinter(FEAFortranPrinter()) + + def _fortran_declare(type_decl, vars_list, indent=" "): + """Generate Fortran declaration with line continuation if exceeding 120 chars. + Fortran free-format limit is 132 chars; we use 120 for safety margin. + Continuation uses '&' at end of line and '&' at start of continuation. + The comma separator must appear at the end of the line (before &) + so that the continuation line can start cleanly with the next variable. + """ + if not vars_list: + return [] + max_len = 120 + prefix = f"{indent}{type_decl} :: " + # Try single line first + single_line = prefix + ", ".join(vars_list) + if len(single_line) <= max_len: + return [single_line] + # Split across multiple lines with continuation + # Strategy: each line ends with ", &" (comma before ampersand) + # and continuation lines start with "& " then the next variable + result_lines = [] + current = prefix + first = True + for v in vars_list: + # Check if adding this variable (with separator) would exceed limit + if first: + candidate = current + v + else: + candidate = current + ", " + v + if len(candidate) + 2 > max_len and not first: + # End current line with comma + ampersand for continuation + result_lines.append(current + ", &") + current = f"{indent}& {v}" + first = False + else: + current = candidate + first = False + result_lines.append(current) + return result_lines + + lines = [ + "! Generated by sympy_codegen.py. Do not edit.", + "!", + f"! Subroutine: compute_{model.name}", + "!", + "! Input array layout (in_vec):", + ] + + # 添加输入信息 + for i, name in enumerate(model.input_names): + lines.append(f"! in_vec({i + 1}): {name}") + + lines.append("!") + lines.append("! Output array layout (out_vec):") + + # 添加输出信息 + for i, name in enumerate(model.output_names): + lines.append(f"! out_vec({i + 1}): {name}") + + lines.extend([ + "!", + f"subroutine compute_{model.name}(in_vec, out_vec)", + " implicit none", + f" double precision, intent(in) :: in_vec({len(model.inputs)})", + f" double precision, intent(out) :: out_vec({len(model.outputs)})", + " ! --- Unpack inputs ---", + ]) + + # Unpack input array to named variables + input_vars = [] + for i, sym in enumerate(model.inputs): + s = str(sym) + if s.isidentifier(): + input_vars.append(s) + + # First, declare all input variables (with line continuation if needed) + if input_vars: + lines.extend(_fortran_declare("double precision", input_vars, " ")) + + # Then assign values + for i, sym in enumerate(model.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_vec({i + 1})") + + # 外部算子调用 + if lowered.external_calls: + lines.append("") + lines.append(" ! --- External Operator Calls ---") + # 先声明所有外部算子相关的变量 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.extend(_fortran_declare("double precision", + [f"{call.prefix}_in({op.n_inputs})", f"{call.prefix}_out({op.n_outputs})"], " ")) + out_var_names = [str(sym) for sym in call.output_symbols] + lines.extend(_fortran_declare("double precision", out_var_names, " ")) + + # 然后赋值和调用 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.append(f" ! External Operator: {call.op_name}") + for i, expr in enumerate(call.input_exprs): + lines.append(f" {call.prefix}_in({i + 1}) = {printer.doprint(expr)}") + lines.append(f" call {op.fortran_func}({call.prefix}_in, {call.prefix}_out)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out({i + 1})") + + lines.append("") + + lines.append(" ! --- Local Variables for CSE ---") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + lines.append(f" ! Chunk {chunk.chunk_index}") + if chunk.sub_exprs: + lines.append(" block") + + # Separate variables by type: logical for comparisons, double precision otherwise + dp_vars = [] + log_vars = [] + for var, expr in chunk.sub_exprs: + if isinstance(expr, Relational): + log_vars.append(str(var)) + else: + dp_vars.append(str(var)) + + if dp_vars: + lines.extend(_fortran_declare("double precision", dp_vars, " ")) + if log_vars: + lines.extend(_fortran_declare("logical", log_vars, " ")) + + # Then assign values + for var, expr in chunk.sub_exprs: + lines.append(f" {var} = {printer.doprint(expr)}") + + for j, out_expr in enumerate(chunk.simplified_outputs): + # Fortran arrays are 1-based. + lines.append(f" out_vec({chunk.start_index + j + 1}) = {printer.doprint(out_expr)}") + + if chunk.sub_exprs: + lines.append(" end block") + + lines.append(f"end subroutine compute_{model.name}") + src = "\n".join(lines) + # Replace C-style array access in[i] with Fortran 1-based in_vec(i+1) + # This handles SymPy symbols like in[0], in[1] that are not valid identifiers + def _replace_in_array(m): + idx = int(m.group(1)) + return f"in_vec({idx + 1})" + src = re.sub(r'in\[(\d+)\]', _replace_in_array, src) + return src + + @staticmethod + def compile_element(element: Element, target: str, chunk_size=None, cse_strategy="auto"): + """ + Special compiler for Elements: supports both single-kernel and operator-based generation. + """ + operators = element.get_stiffness_operators() + if operators: + # Generate multiple operator kernels + generated = {} + for op_model in operators: + generated[op_model.name] = FEACompiler.compile(op_model, target, + chunk_size=chunk_size, cse_strategy=cse_strategy) + return generated + else: + # Traditional single kernel + model = element.get_stiffness_model() + return {model.name: FEACompiler.compile(model, target, + chunk_size=chunk_size, cse_strategy=cse_strategy)} diff --git a/codegen/loader.py b/codegen/loader.py new file mode 100644 index 0000000..4d9a88c --- /dev/null +++ b/codegen/loader.py @@ -0,0 +1,24 @@ +import importlib + +from definitions.abc import Element, Material + + +def _load_class(module_path: str, class_name: str): + """动态加载类""" + try: + module = importlib.import_module(module_path) + return getattr(module, class_name) + except (ModuleNotFoundError, AttributeError) as e: + raise ImportError(f"Could not find class '{class_name}' in module '{module_path}'.\nError: {e}") + +def load_element(name: str) -> Element: + """Loads an element class from the definitions.elements directory.""" + class_name = name.capitalize() + module_path = f"definitions.elements.{name}" + return _load_class(module_path, class_name)() + +def load_material(name: str) -> Material: + """Loads a material class from the definitions.materials directory.""" + class_name = name.capitalize() + module_path = f"definitions.materials.{name}" + return _load_class(module_path, class_name)() diff --git a/codegen/lowered.py b/codegen/lowered.py new file mode 100644 index 0000000..e44174a --- /dev/null +++ b/codegen/lowered.py @@ -0,0 +1,52 @@ +class ExternalOperator: + """描述一个外部实现的算子(向量输入、向量输出)。""" + def __init__(self, name, n_inputs, n_outputs, cpp_func=None, fortran_func=None, jax_func=None): + self.name = name # 算子名称,如 "inv12" + self.n_inputs = n_inputs # 输入元素个数,如 144 + self.n_outputs = n_outputs # 输出元素个数,如 144 + self.cpp_func = cpp_func or name # C++ 函数名 + self.fortran_func = fortran_func or name # Fortran 子程序名 + self.jax_func = jax_func # JAX 函数名,None 表示不支持 + + +class ExternalCall: + """记录 MathModel 中一次外部算子调用。""" + def __init__(self, op_name, input_exprs, output_symbols, prefix=None): + self.op_name = op_name # 对应 ExternalOperator 的 name + self.input_exprs = list(input_exprs) # SymPy 表达式列表(输入) + self.output_symbols = list(output_symbols) # SymPy 符号列表(输出占位) + self.prefix = prefix or op_name # 变量名前缀,避免多次调用冲突 + + +class LoweredChunk: + """单个 chunk 的 CSE 结果""" + def __init__(self, chunk_index: int, start_index: int, sub_exprs: list, simplified_outputs: list): + self.chunk_index = chunk_index + self.start_index = start_index + self.sub_exprs = sub_exprs # list of (Symbol, Expr) + self.simplified_outputs = simplified_outputs # list of Expr + + +class LoweredModel: + """整个模型经过 CSE lowering 后的结果""" + def __init__(self, model_name: str, chunk_size: int, chunks: list, + external_calls=None, external_ops=None): + self.model_name = model_name + self.chunk_size = chunk_size + self.chunks = chunks # list of LoweredChunk + self.external_calls = external_calls or [] # list[ExternalCall] + self.external_ops = external_ops or {} # dict[str, ExternalOperator] + + +class CachedPrinter: + """带 memo cache 的 printer 封装器""" + def __init__(self, printer): + self.printer = printer + self.cache = {} + + def doprint(self, expr): + if expr in self.cache: + return self.cache[expr] + result = self.printer.doprint(expr) + self.cache[expr] = result + return result diff --git a/codegen/model.py b/codegen/model.py new file mode 100644 index 0000000..234eff6 --- /dev/null +++ b/codegen/model.py @@ -0,0 +1,49 @@ +import sympy as sp + +from codegen.lowered import ExternalOperator, ExternalCall # noqa: F401 — re-exported via __init__ + + +class MathModel: + """数据容器:存储数学定义""" + def __init__(self, inputs, outputs, name="kernel", input_names=None, output_names=None, + is_operator=False, external_ops=None, external_calls=None): + self.inputs = inputs # SymPy 符号列表 + self.outputs = outputs # SymPy 表达式列表 + self.name = name + self.input_names = input_names or [str(s) for s in inputs] + self.output_names = output_names or [f"out[{i}]" for i in range(len(outputs))] + self.is_operator = is_operator # 是否作为算子生成(可能包含SIMD优化等) + self.external_ops = external_ops or {} # dict[str, ExternalOperator] + self.external_calls = external_calls or [] # list[ExternalCall] + + +def external_call(model, op_name, input_exprs, n_outputs=None, prefix=None): + """ + 在 MathModel 中注册一次外部算子调用,返回输出符号列表。 + + Args: + model: MathModel 实例 + op_name: model.external_ops 中的键名 + input_exprs: 输入表达式列表 (SymPy 表达式) + n_outputs: 输出数量(如不提供则从 external_ops 查找) + prefix: 输出符号前缀(默认等于 op_name,多次调用同一算子时需不同前缀) + + Returns: + list[sp.Symbol]: 输出占位符号列表,可用于后续表达式 + """ + op = model.external_ops[op_name] + n_outputs = n_outputs or op.n_outputs + prefix = prefix or op_name + + output_symbols = list(sp.symbols(f"{prefix}_0:{n_outputs}")) + + model.external_calls.append( + ExternalCall( + op_name=op_name, + input_exprs=list(input_exprs), + output_symbols=output_symbols, + prefix=prefix, + ) + ) + + return output_symbols diff --git a/codegen/printer.py b/codegen/printer.py new file mode 100644 index 0000000..dd89b67 --- /dev/null +++ b/codegen/printer.py @@ -0,0 +1,133 @@ +from sympy.printing.c import C99CodePrinter +from sympy.printing.fortran import FCodePrinter + + +class FEACodePrinter(C99CodePrinter): + """ + 专门为有限元计算优化的代码打印机: + 1. 展开低次幂 pow(x, 2) -> (x*x) + 2. 优化倒数 pow(x, -1) -> (1.0/(x)) + 3. 优化平方根和立方根,及分数次幂组合 + """ + def _print_Pow(self, expr): + base, exp = expr.as_base_exp() + s_base = self._print(base) + + # 处理整数幂 (2, 3, -1, -2) + if exp.is_Integer: + if exp == 2: + return f"({s_base} * {s_base})" + elif exp == 3: + return f"({s_base} * {s_base} * {s_base})" + elif exp == -1: + return f"(1.0 / ({s_base}))" + elif exp == -2: + return f"(1.0 / ({s_base} * {s_base}))" + + # 处理分数幂,尽量转化为 sqrt 和 cbrt 的乘除法 + # 注意:这里使用浮点比较处理 SymPy 的 Rational 或 Float + try: + val = float(exp) + except TypeError: + return super()._print_Pow(expr) + + # 1/2 系列 (sqrt) + if abs(val - 0.5) < 1e-9: + return f"sqrt({s_base})" + if abs(val + 0.5) < 1e-9: + return f"(1.0 / sqrt({s_base}))" + + # 1/3 系列 (cbrt) + if abs(val - 1.0/3.0) < 1e-7: + return f"cbrt({s_base})" + if abs(val + 1.0/3.0) < 1e-7: + return f"(1.0 / cbrt({s_base}))" + + # 2/3 系列 + if abs(val - 2.0/3.0) < 1e-7: + return f"(cbrt({s_base}) * cbrt({s_base}))" + if abs(val + 2.0/3.0) < 1e-7: + return f"(1.0 / (cbrt({s_base}) * cbrt({s_base})))" + + # 5/6 系列 (5/6 = 1/2 + 1/3) + if abs(val - 5.0/6.0) < 1e-7: + return f"(sqrt({s_base}) * cbrt({s_base}))" + if abs(val + 5.0/6.0) < 1e-7: + return f"(1.0 / (sqrt({s_base}) * cbrt({s_base})))" + + # 其他情况回退到标准 pow + return super()._print_Pow(expr) + + +class FEAFortranPrinter(FCodePrinter): + """ + Fortran 90/95+ optimized code printer: + 1. Enforce double precision constants (1.0 -> 1.0d0) + 2. Expand low-order powers to help vectorization + 3. Optimize fractional powers + """ + def __init__(self, settings=None): + settings = settings or {} + settings.update({"standard": 95, "source_format": "free"}) + super().__init__(settings) + + def _print_Piecewise(self, expr): + # Ensure integer default values in Piecewise are printed as double precision + # to avoid type mismatch in the Fortran merge() intrinsic. + if expr.args[-1].cond == True: + default = expr.args[-1].expr + if default.is_Integer: + result = f"{float(default)}d0" + else: + result = self._print(default) + for e, c in reversed(expr.args[:-1]): + result = "merge(%s, %s, %s)" % (self._print(e), result, self._print(c)) + return result + return super()._print_Piecewise(expr) + + def _print_Float(self, expr): + # Keep all floating constants in double precision. + res = super()._print_Float(expr) + return res + "d0" if "d" not in res.lower() and "e" not in res.lower() else res + + def _print_Pow(self, expr): + base, exp = expr.as_base_exp() + s_base = self._print(base) + s_exp = self._print(exp) + + # Integer powers + if exp.is_Integer: + if exp == 2: + return f"({s_base} * {s_base})" + if exp == 3: + return f"({s_base} * {s_base} * {s_base})" + if exp == -1: + return f"(1.0d0 / ({s_base}))" + if exp == -2: + return f"(1.0d0 / ({s_base} * {s_base}))" + + try: + val = float(exp) + except TypeError: + return super()._print_Pow(expr) + + # Fractional powers + if abs(val - 0.5) < 1e-9: + return f"sqrt({s_base})" + if abs(val + 0.5) < 1e-9: + return f"(1.0d0 / sqrt({s_base}))" + if abs(val - 1.0 / 3.0) < 1e-7: + return f"({s_base}**(1.0d0/3.0d0))" + if abs(val + 1.0 / 3.0) < 1e-7: + return f"(1.0d0 / ({s_base}**(1.0d0/3.0d0)))" + if abs(val - 2.0 / 3.0) < 1e-7: + return f"({s_base}**(2.0d0/3.0d0))" + if abs(val + 2.0 / 3.0) < 1e-7: + return f"(1.0d0 / ({s_base}**(2.0d0/3.0d0)))" + if abs(val - 5.0 / 6.0) < 1e-7: + return f"({s_base}**(5.0d0/6.0d0))" + if abs(val + 5.0 / 6.0) < 1e-7: + return f"(1.0d0 / ({s_base}**(5.0d0/6.0d0)))" + + # Parenthesize the exponent to preserve precedence in Fortran. + return f"({s_base}**({s_exp}))" diff --git a/sympy_codegen.py b/sympy_codegen.py index 41af213..1dc183c 100644 --- a/sympy_codegen.py +++ b/sympy_codegen.py @@ -1,1060 +1,13 @@ -import sympy as sp -from sympy.core.relational import Relational -from sympy.printing.c import C99CodePrinter -from sympy.printing.fortran import FCodePrinter -from sympy.printing.numpy import JaxPrinter -import argparse -import os -import re -import sys -import importlib.util -import importlib -from pathlib import Path +"""兼容性入口 — 所有公开符号从 codegen 包 re-export,保证 from sympy_codegen import ... 仍然可用。""" + +from codegen import * # noqa: F401,F403 +from codegen.cli import main # noqa: F401 — CLI 入口单独导出 # Add the project root to the Python path to allow finding the 'definitions' module +import sys +from pathlib import Path sys.path.append(str(Path(__file__).parent.resolve())) -from definitions.abc import Element, Material -from ci_test.wrappers import generate_cpp_main, generate_f90_main -from ci_test.test_driver_template import generate_test_driver -from ci_test.build_script_generator import generate_build_sh, generate_build_bat -from ci_test.ci_workflow_generator import generate_github_actions_workflow - - -# --------------------------------------------------------------------------- -# 辅助数据结构:用于 CSE 结果缓存和跨后端共享 -# --------------------------------------------------------------------------- -class ExternalOperator: - """描述一个外部实现的算子(向量输入、向量输出)。""" - def __init__(self, name, n_inputs, n_outputs, cpp_func=None, fortran_func=None, jax_func=None): - self.name = name # 算子名称,如 "inv12" - self.n_inputs = n_inputs # 输入元素个数,如 144 - self.n_outputs = n_outputs # 输出元素个数,如 144 - self.cpp_func = cpp_func or name # C++ 函数名 - self.fortran_func = fortran_func or name # Fortran 子程序名 - self.jax_func = jax_func # JAX 函数名,None 表示不支持 - - -class ExternalCall: - """记录 MathModel 中一次外部算子调用。""" - def __init__(self, op_name, input_exprs, output_symbols, prefix=None): - self.op_name = op_name # 对应 ExternalOperator 的 name - self.input_exprs = list(input_exprs) # SymPy 表达式列表(输入) - self.output_symbols = list(output_symbols) # SymPy 符号列表(输出占位) - self.prefix = prefix or op_name # 变量名前缀,避免多次调用冲突 - - -class LoweredChunk: - """单个 chunk 的 CSE 结果""" - def __init__(self, chunk_index: int, start_index: int, sub_exprs: list, simplified_outputs: list): - self.chunk_index = chunk_index - self.start_index = start_index - self.sub_exprs = sub_exprs # list of (Symbol, Expr) - self.simplified_outputs = simplified_outputs # list of Expr - - -class LoweredModel: - """整个模型经过 CSE lowering 后的结果""" - def __init__(self, model_name: str, chunk_size: int, chunks: list, - external_calls=None, external_ops=None): - self.model_name = model_name - self.chunk_size = chunk_size - self.chunks = chunks # list of LoweredChunk - self.external_calls = external_calls or [] # list[ExternalCall] - self.external_ops = external_ops or {} # dict[str, ExternalOperator] - - -class CachedPrinter: - """带 memo cache 的 printer 封装器""" - def __init__(self, printer): - self.printer = printer - self.cache = {} - - def doprint(self, expr): - if expr in self.cache: - return self.cache[expr] - result = self.printer.doprint(expr) - self.cache[expr] = result - return result - - -# --------------------------------------------------------------------------- -# 数据容器 + 静态编译分发 -# --------------------------------------------------------------------------- -class FEACodePrinter(C99CodePrinter): - """ - 专门为有限元计算优化的代码打印机: - 1. 展开低次幂 pow(x, 2) -> (x*x) - 2. 优化倒数 pow(x, -1) -> (1.0/(x)) - 3. 优化平方根和立方根,及分数次幂组合 - """ - def _print_Pow(self, expr): - base, exp = expr.as_base_exp() - s_base = self._print(base) - - # 处理整数幂 (2, 3, -1, -2) - if exp.is_Integer: - if exp == 2: - return f"({s_base} * {s_base})" - elif exp == 3: - return f"({s_base} * {s_base} * {s_base})" - elif exp == -1: - return f"(1.0 / ({s_base}))" - elif exp == -2: - return f"(1.0 / ({s_base} * {s_base}))" - - # 处理分数幂,尽量转化为 sqrt 和 cbrt 的乘除法 - # 注意:这里使用浮点比较处理 SymPy 的 Rational 或 Float - try: - val = float(exp) - except TypeError: - return super()._print_Pow(expr) - - # 1/2 系列 (sqrt) - if abs(val - 0.5) < 1e-9: - return f"sqrt({s_base})" - if abs(val + 0.5) < 1e-9: - return f"(1.0 / sqrt({s_base}))" - - # 1/3 系列 (cbrt) - if abs(val - 1.0/3.0) < 1e-7: - return f"cbrt({s_base})" - if abs(val + 1.0/3.0) < 1e-7: - return f"(1.0 / cbrt({s_base}))" - - # 2/3 系列 - if abs(val - 2.0/3.0) < 1e-7: - return f"(cbrt({s_base}) * cbrt({s_base}))" - if abs(val + 2.0/3.0) < 1e-7: - return f"(1.0 / (cbrt({s_base}) * cbrt({s_base})))" - - # 5/6 系列 (5/6 = 1/2 + 1/3) - if abs(val - 5.0/6.0) < 1e-7: - return f"(sqrt({s_base}) * cbrt({s_base}))" - if abs(val + 5.0/6.0) < 1e-7: - return f"(1.0 / (sqrt({s_base}) * cbrt({s_base})))" - - # 其他情况回退到标准 pow - return super()._print_Pow(expr) - - -class FEAFortranPrinter(FCodePrinter): - """ - Fortran 90/95+ optimized code printer: - 1. Enforce double precision constants (1.0 -> 1.0d0) - 2. Expand low-order powers to help vectorization - 3. Optimize fractional powers - """ - def __init__(self, settings=None): - settings = settings or {} - settings.update({"standard": 95, "source_format": "free"}) - super().__init__(settings) - - def _print_Piecewise(self, expr): - # Ensure integer default values in Piecewise are printed as double precision - # to avoid type mismatch in the Fortran merge() intrinsic. - if expr.args[-1].cond == True: - default = expr.args[-1].expr - if default.is_Integer: - result = f"{float(default)}d0" - else: - result = self._print(default) - for e, c in reversed(expr.args[:-1]): - result = "merge(%s, %s, %s)" % (self._print(e), result, self._print(c)) - return result - return super()._print_Piecewise(expr) - - def _print_Float(self, expr): - # Keep all floating constants in double precision. - res = super()._print_Float(expr) - return res + "d0" if "d" not in res.lower() and "e" not in res.lower() else res - - def _print_Pow(self, expr): - base, exp = expr.as_base_exp() - s_base = self._print(base) - s_exp = self._print(exp) - - # Integer powers - if exp.is_Integer: - if exp == 2: - return f"({s_base} * {s_base})" - if exp == 3: - return f"({s_base} * {s_base} * {s_base})" - if exp == -1: - return f"(1.0d0 / ({s_base}))" - if exp == -2: - return f"(1.0d0 / ({s_base} * {s_base}))" - - try: - val = float(exp) - except TypeError: - return super()._print_Pow(expr) - - # Fractional powers - if abs(val - 0.5) < 1e-9: - return f"sqrt({s_base})" - if abs(val + 0.5) < 1e-9: - return f"(1.0d0 / sqrt({s_base}))" - if abs(val - 1.0 / 3.0) < 1e-7: - return f"({s_base}**(1.0d0/3.0d0))" - if abs(val + 1.0 / 3.0) < 1e-7: - return f"(1.0d0 / ({s_base}**(1.0d0/3.0d0)))" - if abs(val - 2.0 / 3.0) < 1e-7: - return f"({s_base}**(2.0d0/3.0d0))" - if abs(val + 2.0 / 3.0) < 1e-7: - return f"(1.0d0 / ({s_base}**(2.0d0/3.0d0)))" - if abs(val - 5.0 / 6.0) < 1e-7: - return f"({s_base}**(5.0d0/6.0d0))" - if abs(val + 5.0 / 6.0) < 1e-7: - return f"(1.0d0 / ({s_base}**(5.0d0/6.0d0)))" - - # Parenthesize the exponent to preserve precedence in Fortran. - return f"({s_base}**({s_exp}))" - - -class MathModel: - """数据容器:存储数学定义""" - def __init__(self, inputs, outputs, name="kernel", input_names=None, output_names=None, - is_operator=False, external_ops=None, external_calls=None): - self.inputs = inputs # SymPy 符号列表 - self.outputs = outputs # SymPy 表达式列表 - self.name = name - self.input_names = input_names or [str(s) for s in inputs] - self.output_names = output_names or [f"out[{i}]" for i in range(len(outputs))] - self.is_operator = is_operator # 是否作为算子生成(可能包含SIMD优化等) - self.external_ops = external_ops or {} # dict[str, ExternalOperator] - self.external_calls = external_calls or [] # list[ExternalCall] - - -def external_call(model, op_name, input_exprs, n_outputs=None, prefix=None): - """ - 在 MathModel 中注册一次外部算子调用,返回输出符号列表。 - - Args: - model: MathModel 实例 - op_name: model.external_ops 中的键名 - input_exprs: 输入表达式列表 (SymPy 表达式) - n_outputs: 输出数量(如不提供则从 external_ops 查找) - prefix: 输出符号前缀(默认等于 op_name,多次调用同一算子时需不同前缀) - - Returns: - list[sp.Symbol]: 输出占位符号列表,可用于后续表达式 - """ - op = model.external_ops[op_name] - n_outputs = n_outputs or op.n_outputs - prefix = prefix or op_name - - output_symbols = list(sp.symbols(f"{prefix}_0:{n_outputs}")) - - model.external_calls.append( - ExternalCall( - op_name=op_name, - input_exprs=list(input_exprs), - output_symbols=output_symbols, - prefix=prefix, - ) - ) - - return output_symbols - - -class FEACompiler: - # ========================================================================= - # 公共 Lower 阶段:将 MathModel 转换为 LoweredModel,执行 CSE - # ========================================================================= - @staticmethod - def lower_model(model: MathModel, chunk_size: int) -> LoweredModel: - """执行 CSE lowering,返回可被多个后端共享的 LoweredModel""" - outputs = model.outputs - chunks = [] - - for start in range(0, len(outputs), chunk_size): - chunk_index = start // chunk_size - chunk = outputs[start:start + chunk_size] - sub_exprs, simplified_chunk = sp.cse( - chunk, - symbols=sp.numbered_symbols(f"v_{chunk_index}_") - ) - chunks.append( - LoweredChunk( - chunk_index=chunk_index, - start_index=start, - sub_exprs=sub_exprs, - simplified_outputs=simplified_chunk - ) - ) - - return LoweredModel(model.name, chunk_size, chunks, - external_calls=model.external_calls, - external_ops=model.external_ops) - - # ========================================================================= - # Chunk Size 策略:根据模型规模和目标平台决定 chunk size - # ========================================================================= - @staticmethod - def resolve_chunk_size(model: MathModel, target: str, user_chunk_size=None, strategy="auto") -> int: - """ - 决定 CSE chunk size 的策略。 - - Args: - model: 数学模型 - target: 目标平台 (jax/cpp/cuda/fortran等) - user_chunk_size: 用户通过 CLI 指定的 chunk size - strategy: 策略模式 ("auto" 或 "fixed") - - Returns: - 最终的 chunk size - """ - if user_chunk_size is not None: - return user_chunk_size - - nout = len(model.outputs) - target = target.lower() - - # fixed 模式:使用各后端的固定默认值 - if strategy == "fixed": - if target == "jax": - return 50 - if target in ("cpp", "c++", "cuda", "fortran"): - return 24 - return 24 - - # auto 模式:根据输出规模自动调整 - if strategy == "auto": - if target == "jax": - if nout <= 64: - return 64 - elif nout <= 256: - return 48 - else: - return 32 - - # cpp/cuda/fortran 的自适应策略 - if nout <= 32: - return 32 - elif nout <= 128: - return 24 - elif nout <= 512: - return 16 - else: - return 8 - - raise ValueError(f"Unknown strategy: {strategy}") - - # ========================================================================= - # C++/CUDA 兼容性宏:跨平台支持 GCC/Clang/MSVC/CUDA - # ========================================================================= - @staticmethod - def _cpp_cuda_compat_macros() -> str: - """返回统一的 C++/CUDA 跨平台兼容性宏定义""" - return r""" -#if defined(__CUDACC__) - #define FEA_DEVICE __device__ - #define FEA_HOST __host__ - #define FEA_HOST_DEVICE __host__ __device__ - #define FEA_RESTRICT __restrict__ -#else - #define FEA_DEVICE - #define FEA_HOST - #define FEA_HOST_DEVICE - #if defined(_WIN32) || defined(_WIN64) - #if defined(_MSC_VER) - #define FEA_RESTRICT __restrict - #else - #define FEA_RESTRICT __restrict__ - #endif - #else - #if defined(__GNUC__) || defined(__clang__) - #define FEA_RESTRICT __restrict__ - #else - #define FEA_RESTRICT - #endif - #endif -#endif - -#if defined(_MSC_VER) - #define FEA_ALWAYS_INLINE __forceinline -#elif defined(__GNUC__) || defined(__clang__) - #define FEA_ALWAYS_INLINE inline __attribute__((always_inline)) -#else - #define FEA_ALWAYS_INLINE inline -#endif -""" - - # ========================================================================= - # 核心编译接口 - # ========================================================================= - @staticmethod - def compile(model: MathModel, target: str, chunk_size=None, cse_strategy="auto", lowered=None): - """ - 核心分发器:输入 MathModel + target,输出 cpp/cuda/jax/fortran 源码字符串。 - - Args: - model: 数学模型 - target: 目标平台 ('jax', 'cpp', 'cuda', 'fortran') - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - lowered: 预先 lowered 的结果 (可选,用于多后端共享) - """ - target = target.lower() - if target == 'jax': - return FEACompiler._to_jax(model, lowered=lowered, chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target in ['cpp', 'c++']: - return FEACompiler._to_source(model, is_cuda=False, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'cuda': - return FEACompiler._to_source(model, is_cuda=True, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'fortran': - return FEACompiler._to_fortran(model, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - else: - raise ValueError(f"Unknown target: {target}") - - @staticmethod - def compile_all(model: MathModel, chunk_size=None, cse_strategy="auto", test=False, - task=None, model_name=None): - """ - 一次性生成 jax/cpp/cuda/fortran 四种目标源码。 - - 统一管理 lower 行为: - - 如果所有 target 使用相同的 chunk size,共享一份 lowered - - 如果 JAX 和 cpp/cuda/fortran 使用不同的 chunk size,分别生成 jax_lowered 和 shared_lowered - - Args: - model: 数学模型 - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - test: 是否同时生成测试资产(wrapper、test_driver、build script) - task: CLI 任务类型 ('constitutive', 'stiffness', 'mass', 'custom'),用于 test_driver 重新加载模型 - model_name: 模型/材料/单元名称,用于 test_driver 重新加载模型 - - Returns: - dict: {'jax': code, 'cpp': code, 'cuda': code, 'fortran': code, - 'cpp_wrapper': str, 'f90_wrapper': str, 'test_driver': str, - 'build_sh': str, 'build_bat': str} (后5项仅在 test=True 时存在) - """ - # 决定各 target 的 chunk size - cpp_chunk = FEACompiler.resolve_chunk_size(model, "cpp", chunk_size, cse_strategy) - jax_chunk = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) - - # 生成 shared lowered 给 cpp/cuda/fortran - shared_lowered = FEACompiler.lower_model(model, cpp_chunk) - - # 决定 JAX 是否共享 lowered - if jax_chunk == cpp_chunk: - jax_lowered = shared_lowered - else: - jax_lowered = FEACompiler.lower_model(model, jax_chunk) - - result = { - "jax": FEACompiler._to_jax(model, lowered=jax_lowered, chunk_size=jax_chunk, cse_strategy=cse_strategy), - "cpp": FEACompiler._to_source(model, is_cuda=False, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - "cuda": FEACompiler._to_source(model, is_cuda=True, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - "fortran": FEACompiler._to_fortran(model, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - } - - if test: - result["cpp_wrapper"] = generate_cpp_main(model) - result["f90_wrapper"] = generate_f90_main(model) - result["test_driver"] = generate_test_driver(model, task=task, model_name=model_name) - result["build_sh"] = generate_build_sh(model) - result["build_bat"] = generate_build_bat(model) - - return result - - @staticmethod - def _to_jax(model, lowered=None, chunk_size=None, cse_strategy="auto"): - """ - 生成 JAX 源码(.py),采用分块 CSE 优化。 - - Args: - model: 数学模型 - lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - """ - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - lines = [ - '"""Generated by sympy_codegen.py. Do not edit."""', - "import jax.numpy as jnp", - "", - "", - f"def compute_{model.name}(in_flat):", - f' """', - f' Compute the {model.name} kernel.', - f' ', - f' Args:', - f' in_flat: Flattened input array, size {len(model.inputs)}', - f' ', - f' Returns:', - f' Flattened output array, size {len(model.outputs)}', - f' ', - f' Input layout:', - ] - - # 添加输入信息 - for i, name in enumerate(model.input_names): - lines.append(f" ' - in_flat[{i}]: {name}") - - lines.append(f" '") - lines.append(f" ' Output layout:") - - # 添加输出信息 - for i, name in enumerate(model.output_names): - lines.append(f" ' - out[{i}]: {name}") - - lines.append(f' """') - - # Unpack inputs IF they are valid identifiers (e.g. xi, c0) - # If they are like "in[0]", we'll handle them via string replacement later - for i, sym in enumerate(model.inputs): - s = str(sym) - is_ident = s.isidentifier() - # print(f"DEBUG: sym={s}, is_ident={is_ident}") - if is_ident: - lines.append(f" {s} = in_flat[{i}]") - - lines.append("") - - printer = CachedPrinter(JaxPrinter()) - all_simplified_outputs = [] - - # 外部算子调用 - if lowered.external_calls: - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - if op.jax_func is None: - raise ValueError( - f"External operator '{call.op_name}' has no JAX implementation. " - f"Cannot generate JAX code for model '{model.name}'." - ) - lines.append(f" # --- External Operator: {call.op_name} ---") - in_parts = ", ".join(printer.doprint(e) for e in call.input_exprs) - lines.append(f" {call.prefix}_in = jnp.array([{in_parts}])") - lines.append(f" {call.prefix}_out = {op.jax_func}({call.prefix}_in)") - for i, sym in enumerate(call.output_symbols): - lines.append(f" {sym} = {call.prefix}_out[{i}]") - lines.append("") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - for var, expr in chunk.sub_exprs: - lines.append(f" {var} = {printer.doprint(expr)}") - - all_simplified_outputs.extend(chunk.simplified_outputs) - - lines.append("") - lines.append(" # --- Output ---") - out_parts = [printer.doprint(e) for e in all_simplified_outputs] - lines.append(f" return ({','.join(out_parts)})") - - src = "\n".join(lines) - # Final cleanup for JAX and handle C-style inputs - src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") - return src - - @staticmethod - def _to_source(model, is_cuda=False, lowered=None, chunk_size=None, cse_strategy="auto"): - """ - 生成 C++/CUDA 源码,采用分块 CSE 优化及算子化增强。 - - Args: - model: 数学模型 - is_cuda: 是否为 CUDA 目标 - lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - """ - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "cuda" if is_cuda else "cpp", - chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - # --- Generate Comments --- - comment_lines = ["/**"] - comment_lines.append(f" * @brief Computes the {model.name} kernel.") - if model.is_operator: - comment_lines.append(" * @note This is an optimized operator kernel.") - comment_lines.append(" * ") - comment_lines.append(" * @param in Input array (const double*). Layout:") - - for i, name in enumerate(model.input_names): - comment_lines.append(f" * - in[{i}]: {name}") - - comment_lines.append(" * ") - comment_lines.append(" * @param out Output array (double*). Layout:") - - # 列出每个输出的详细信息 - for i, name in enumerate(model.output_names): - comment_lines.append(f" * - out[{i}]: {name}") - - comment_lines.append(" */") - comment_block = "\n".join(comment_lines) - - # --- Generate Function Body --- - body_lines = [] - - # 解包输入变量 - for i, sym in enumerate(model.inputs): - s = str(sym) - # 检查是否是合法标识符(如 coord_2_3),如果是则解包 - if s.isidentifier(): - body_lines.append(f" double {s} = in[{i}];") - - body_lines.append("") - - # 初始化带缓存的 Printer - printer = CachedPrinter(FEACodePrinter()) - - # 外部算子调用 - if lowered.external_calls: - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - body_lines.append(f" // --- External Operator: {call.op_name} ---") - body_lines.append(f" double {call.prefix}_in[{op.n_inputs}];") - body_lines.append(f" double {call.prefix}_out[{op.n_outputs}];") - - for i, expr in enumerate(call.input_exprs): - body_lines.append(f" {call.prefix}_in[{i}] = {printer.doprint(expr)};") - - body_lines.append(f" {op.cpp_func}({call.prefix}_in, {call.prefix}_out);") - - for i, sym in enumerate(call.output_symbols): - body_lines.append(f" double {sym} = {call.prefix}_out[{i}];") - - body_lines.append("") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - body_lines.append(f"\n // --- Chunk {chunk.chunk_index} ---") - - for var, expr in chunk.sub_exprs: - body_lines.append(f" double {var} = {printer.doprint(expr)};") - - for j, out_expr in enumerate(chunk.simplified_outputs): - body_lines.append(f" out[{chunk.start_index + j}] = {printer.doprint(out_expr)};") - - body = "\n".join(body_lines) - - # 统一使用兼容宏体系 - prefix = FEACompiler._cpp_cuda_compat_macros() + "\n" - - if is_cuda: - # CUDA 使用 FEA_DEVICE 宏 - func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" - signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" - else: - # C++ 使用 FEA_ALWAYS_INLINE 宏 - func_type = "FEA_ALWAYS_INLINE void" - signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" - - return f"{prefix}{comment_block}\n{signature} {{ \n{body}\n}}" - - - - @staticmethod - def _to_fortran(model, lowered=None, chunk_size=None, cse_strategy="auto"): - """生成 Fortran 源码,支持分块 CSE 优化。声明和赋值必须分离。""" - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "fortran", chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - printer = CachedPrinter(FEAFortranPrinter()) - - def _fortran_declare(type_decl, vars_list, indent=" "): - """Generate Fortran declaration with line continuation if exceeding 120 chars. - Fortran free-format limit is 132 chars; we use 120 for safety margin. - Continuation uses '&' at end of line and '&' at start of continuation. - The comma separator must appear at the end of the line (before &) - so that the continuation line can start cleanly with the next variable. - """ - if not vars_list: - return [] - max_len = 120 - prefix = f"{indent}{type_decl} :: " - # Try single line first - single_line = prefix + ", ".join(vars_list) - if len(single_line) <= max_len: - return [single_line] - # Split across multiple lines with continuation - # Strategy: each line ends with ", &" (comma before ampersand) - # and continuation lines start with "& " then the next variable - result_lines = [] - current = prefix - first = True - for v in vars_list: - # Check if adding this variable (with separator) would exceed limit - if first: - candidate = current + v - else: - candidate = current + ", " + v - if len(candidate) + 2 > max_len and not first: - # End current line with comma + ampersand for continuation - result_lines.append(current + ", &") - current = f"{indent}& {v}" - first = False - else: - current = candidate - first = False - result_lines.append(current) - return result_lines - - lines = [ - "! Generated by sympy_codegen.py. Do not edit.", - "!", - f"! Subroutine: compute_{model.name}", - "!", - "! Input array layout (in_vec):", - ] - - # 添加输入信息 - for i, name in enumerate(model.input_names): - lines.append(f"! in_vec({i + 1}): {name}") - - lines.append("!") - lines.append("! Output array layout (out_vec):") - - # 添加输出信息 - for i, name in enumerate(model.output_names): - lines.append(f"! out_vec({i + 1}): {name}") - - lines.extend([ - "!", - f"subroutine compute_{model.name}(in_vec, out_vec)", - " implicit none", - f" double precision, intent(in) :: in_vec({len(model.inputs)})", - f" double precision, intent(out) :: out_vec({len(model.outputs)})", - " ! --- Unpack inputs ---", - ]) - - # Unpack input array to named variables - input_vars = [] - for i, sym in enumerate(model.inputs): - s = str(sym) - if s.isidentifier(): - input_vars.append(s) - - # First, declare all input variables (with line continuation if needed) - if input_vars: - lines.extend(_fortran_declare("double precision", input_vars, " ")) - - # Then assign values - for i, sym in enumerate(model.inputs): - s = str(sym) - if s.isidentifier(): - lines.append(f" {s} = in_vec({i + 1})") - - # 外部算子调用 - if lowered.external_calls: - lines.append("") - lines.append(" ! --- External Operator Calls ---") - # 先声明所有外部算子相关的变量 - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - lines.extend(_fortran_declare("double precision", - [f"{call.prefix}_in({op.n_inputs})", f"{call.prefix}_out({op.n_outputs})"], " ")) - out_var_names = [str(sym) for sym in call.output_symbols] - lines.extend(_fortran_declare("double precision", out_var_names, " ")) - - # 然后赋值和调用 - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - lines.append(f" ! External Operator: {call.op_name}") - for i, expr in enumerate(call.input_exprs): - lines.append(f" {call.prefix}_in({i + 1}) = {printer.doprint(expr)}") - lines.append(f" call {op.fortran_func}({call.prefix}_in, {call.prefix}_out)") - for i, sym in enumerate(call.output_symbols): - lines.append(f" {sym} = {call.prefix}_out({i + 1})") - - lines.append("") - - lines.append(" ! --- Local Variables for CSE ---") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - lines.append(f" ! Chunk {chunk.chunk_index}") - if chunk.sub_exprs: - lines.append(" block") - - # Separate variables by type: logical for comparisons, double precision otherwise - dp_vars = [] - log_vars = [] - for var, expr in chunk.sub_exprs: - if isinstance(expr, Relational): - log_vars.append(str(var)) - else: - dp_vars.append(str(var)) - - if dp_vars: - lines.extend(_fortran_declare("double precision", dp_vars, " ")) - if log_vars: - lines.extend(_fortran_declare("logical", log_vars, " ")) - - # Then assign values - for var, expr in chunk.sub_exprs: - lines.append(f" {var} = {printer.doprint(expr)}") - - for j, out_expr in enumerate(chunk.simplified_outputs): - # Fortran arrays are 1-based. - lines.append(f" out_vec({chunk.start_index + j + 1}) = {printer.doprint(out_expr)}") - - if chunk.sub_exprs: - lines.append(" end block") - - lines.append(f"end subroutine compute_{model.name}") - src = "\n".join(lines) - # Replace C-style array access in[i] with Fortran 1-based in_vec(i+1) - # This handles SymPy symbols like in[0], in[1] that are not valid identifiers - def _replace_in_array(m): - idx = int(m.group(1)) - return f"in_vec({idx + 1})" - src = re.sub(r'in\[(\d+)\]', _replace_in_array, src) - return src - - @staticmethod - def compile_element(element: Element, target: str, chunk_size=None, cse_strategy="auto"): - """ - Special compiler for Elements: supports both single-kernel and operator-based generation. - """ - operators = element.get_stiffness_operators() - if operators: - # Generate multiple operator kernels - generated = {} - for op_model in operators: - generated[op_model.name] = FEACompiler.compile(op_model, target, - chunk_size=chunk_size, cse_strategy=cse_strategy) - return generated - else: - # Traditional single kernel - model = element.get_stiffness_model() - return {model.name: FEACompiler.compile(model, target, - chunk_size=chunk_size, cse_strategy=cse_strategy)} - - -# --------------------------------------------------------------------------- -# 动态模型加载 -# --------------------------------------------------------------------------- -def _load_class(module_path: str, class_name: str): - """动态加载类""" - try: - module = importlib.import_module(module_path) - return getattr(module, class_name) - except (ModuleNotFoundError, AttributeError) as e: - raise ImportError(f"Could not find class '{class_name}' in module '{module_path}'.\nError: {e}") - -def load_element(name: str) -> Element: - """Loads an element class from the definitions.elements directory.""" - class_name = name.capitalize() - module_path = f"definitions.elements.{name}" - return _load_class(module_path, class_name)() - -def load_material(name: str) -> Material: - """Loads a material class from the definitions.materials directory.""" - class_name = name.capitalize() - module_path = f"definitions.materials.{name}" - return _load_class(module_path, class_name)() - - -def _default_output(model_name: str, target: str) -> str: - t = target.lower() - if t == "jax": - return f"{model_name}_gen.py" - if t in ("cpp", "c++"): - return f"{model_name}_gen.cpp" - if t == "cuda": - return f"{model_name}_gen.cu" - if t == "fortran": - return f"{model_name}_gen.f90" - return f"{model_name}_{t}.txt" - -def main(): - parser = argparse.ArgumentParser( - description="SymPy FEA 代码生成器 (混合解耦架构)" - ) - parser.add_argument( - "--task", - required=True, - choices=["constitutive", "stiffness", "mass", "custom"], - help="生成任务: 'constitutive' (材料D矩阵), 'stiffness' (单元Ke矩阵), 'mass' (质量矩阵), 或 'custom' (自定义数学模型)", - ) - parser.add_argument( - "--element", "-e", - help="单元名称 (e.g., 'tet4'), required for --task=stiffness", - ) - parser.add_argument( - "--material", "-m", - help="材料名称 (e.g., 'isotropic'), required for --task=constitutive", - ) - parser.add_argument( - "--script", "-s", - help="Python 脚本路径 (用于 --task=custom). 脚本中需要提供 get_model() 函数返回 MathModel.", - ) - parser.add_argument( - "--target", "-t", - required=True, - choices=["jax", "cpp", "cuda", "fortran", "all"], - help="目标语言:jax / cpp / cuda / fortran / all", - ) - parser.add_argument( - "--output", "-o", - default=None, - help="输出文件路径(默认根据任务和名称生成)", - ) - parser.add_argument( - "--chunk-size", - type=int, - default=None, - help="CSE chunk size. 如果省略,则使用 cse-strategy 决定。", - ) - parser.add_argument( - "--cse-strategy", - choices=["auto", "fixed"], - default="auto", - help="CSE chunk sizing 策略。'auto' 根据输出规模自动调整,'fixed' 使用固定默认值。", - ) - parser.add_argument( - "--test", - action="store_true", - default=False, - help="同时生成 CI 测试资产(C++/Fortran wrapper、test_driver.py、build 脚本)", - ) - parser.add_argument( - "--test-output-dir", - default=None, - help="测试资产输出目录(仅在 --test 启用时有效,默认与 --output 相同)", - ) - args = parser.parse_args() - - if args.task == "constitutive": - if not args.material: - parser.error("--material is required for --task=constitutive") - material = load_material(args.material) - model = material.get_constitutive_model() - models_to_compile = {model.name: model} - - elif args.task == "stiffness": - if not args.element: - parser.error("--element is required for --task=stiffness") - element = load_element(args.element) - operators = element.get_stiffness_operators() - if operators: - models_to_compile = {op.name: op for op in operators} - else: - m = element.get_stiffness_model() - models_to_compile = {m.name: m} - - elif args.task == "mass": - if not args.element: - parser.error("--element is required for --task=mass") - element = load_element(args.element) - operators = element.get_mass_operators() - if operators: - models_to_compile = {op.name: op for op in operators} - else: - # Fallback if no specific mass model is defined, though mass usually has operators - parser.error(f"No mass operators defined for element: {args.element}") - - elif args.task == "custom": - if not args.script: - parser.error("--script is required for --task=custom") - script_path = Path(args.script) - if not script_path.exists(): - parser.error(f"Script file not found: {script_path}") - - # Dynamically load the script - spec = importlib.util.spec_from_file_location("custom_script", str(script_path)) - custom_mod = importlib.util.module_from_spec(spec) - sys.modules["custom_script"] = custom_mod - spec.loader.exec_module(custom_mod) - - if not hasattr(custom_mod, "get_model"): - parser.error(f"Script {script_path} must define a 'get_model()' function.") - - models = custom_mod.get_model() - if type(models).__name__ == "MathModel": - models_to_compile = {models.name: models} - elif isinstance(models, list) and all(type(m).__name__ == "MathModel" for m in models): - models_to_compile = {m.name: m for m in models} - else: - parser.error(f"get_model() must return a MathModel or a list of MathModels. Got: {type(models)}") - - # ---------------- Compile Models ---------------- - base_test_dir = Path(args.test_output_dir or args.output or ".") if args.test else None - - for name, model in models_to_compile.items(): - if args.target == "all": - # --target all: 使用 compile_all 实现真正的共享 CSE - generated = FEACompiler.compile_all( - model, - chunk_size=args.chunk_size, - cse_strategy=args.cse_strategy, - test=args.test, - task=args.task, - model_name=args.material or args.element or name, - ) - for t, code in generated.items(): - if t in ("jax", "cpp", "cuda", "fortran"): - out_path = Path(args.output or ".") / _default_output(name, t) - out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - f.write(code) - print(f"Generated: {out_path}") - # Also copy kernel source to test directory if --test is enabled - if args.test and base_test_dir is not None: - kernel_dir = base_test_dir / name - kernel_dir.mkdir(parents=True, exist_ok=True) - kernel_ext_map = {"cpp": "kernel.cpp", "cuda": "kernel.cu", "fortran": "kernel.f90"} - if t in kernel_ext_map: - kernel_path = kernel_dir / kernel_ext_map[t] - with open(kernel_path, "w", encoding="utf-8") as f: - f.write(code) - elif args.test and t in ("cpp_wrapper", "f90_wrapper", "test_driver", "build_sh", "build_bat"): - # Each model gets its own subdirectory to avoid overwriting - test_dir = base_test_dir / name - test_dir.mkdir(parents=True, exist_ok=True) - fname_map = { - "cpp_wrapper": "main.cpp", - "f90_wrapper": "main.f90", - "test_driver": "test_driver.py", - "build_sh": "build.sh", - "build_bat": "build.bat", - } - out_path = test_dir / fname_map[t] - with open(out_path, "w", encoding="utf-8") as f: - f.write(code) - print(f"Generated: {out_path}") - # Generate codegen_meta.json for test_driver.py to locate sympy_codegen - if args.test and base_test_dir is not None: - import json as _json - test_dir = base_test_dir / name - test_dir.mkdir(parents=True, exist_ok=True) - code_gen_dir = Path(__file__).parent.resolve() - rel_path = os.path.relpath(code_gen_dir, test_dir.resolve()) - meta = {"code_gen_rel_path": rel_path} - meta_path = test_dir / "codegen_meta.json" - with open(meta_path, "w", encoding="utf-8") as f: - _json.dump(meta, f, indent=2) - print(f"Generated: {meta_path}") - else: - # 单一目标编译 - code = FEACompiler.compile( - model, - args.target, - chunk_size=args.chunk_size, - cse_strategy=args.cse_strategy, - ) - out_path = Path(args.output or ".") / _default_output(name, args.target) - with open(out_path, "w", encoding="utf-8") as f: - f.write(code) - print(f"Generated: {out_path}") - if __name__ == "__main__": main() From 4faff197c30c935fc1e743097b929d01e820a961 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 22:27:55 +0800 Subject: [PATCH 03/14] =?UTF-8?q?chore:=20=E6=9B=B4=E6=96=B0=20.gitignore?= =?UTF-8?q?=20=E5=BF=BD=E7=95=A5=20repomix=20=E6=96=87=E4=BB=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index a143449..b6cac1d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,5 @@ dist/ build/ .venv/ hex8r_gen/ +repomix-output.txt +repomix.config.json From c10488d3ddd6360ac45bc24afadc355c946fc948 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 22:51:24 +0800 Subject: [PATCH 04/14] =?UTF-8?q?docs:=20=E6=B7=BB=E5=8A=A0=20IR=20?= =?UTF-8?q?=E5=B1=82=E5=AE=9E=E7=8E=B0=E8=AE=A1=E5=88=92=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plan.md | 57 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 plan.md diff --git a/plan.md b/plan.md new file mode 100644 index 0000000..a9deac4 --- /dev/null +++ b/plan.md @@ -0,0 +1,57 @@ +准备在这个代码生成器中增加IR层,实现mathmodel之间的调度流程 + +分步实现计划 +按照最小化修改原则,分三步落地: + +```mermaid +graph LR + S1["Step 1
数据模型定义"] --> S2["Step 2
C++ 后端编译"] + S2 --> S3["Step 3
Fortran + JAX 后端"] + + style S1 fill:#e1f5fe + style S2 fill:#fff3e0 + style S3 fill:#e8f5e9 + +``` + +**Step 1: 数据模型定义(仅 codegen/model.py)** +新增 FlowModel, Buffer, Assign, BufferZero, BufferCopy, BufferAccum, Call, If, For。纯数据类,无逻辑。更新 __init__.py 导出。 + +改动文件: + +- codegen/model.py — 新增类定义 + +- codegen/__init__.py — 新增导出 + +**Step 2: C++ 后端编译(codegen/compiler.py + codegen/printer.py)** + 在 FEACompiler 中新增 compile_flow() 方法,先生成子模型的 compute_xxx 函数,再生成主流程函数。 + +核心逻辑: + +编译所有 submodels → 得到 compute_xxx 函数源码 + +1. 遍历 body 生成主流程函数体 +2. 处理 Call → 发出函数调用语句,标量输出用 double var;,缓冲区输出用 double var[N]; +3. 处理 For → 发出 for (int idx = start; idx < end; idx++),unroll=True 时展开 +4. 处理 If → 发出 if (cond) { ... } else { ... } +5. 处理 BufferZero/Copy/Accum → 发出对应语句 + + +改动文件: + +codegen/compiler.py — 新增 _flow_to_source(), compile_flow() +codegen/printer.py — 可能需要辅助方法 + +**Step 3: Fortran + JAX 后端** +Fortran 的 For → do 循环,If → if/else,Call → call sub()。 + +JAX 最特殊:For → jax.lax.fori_loop / jax.lax.scan,If → jax.lax.cond,BufferAccum → buffer.at[i].add()。JAX 后端可以后置,先确保 C++/Fortran 可用。 + +改动文件: + +codegen/compiler.py — 新增 _flow_to_fortran(), _flow_to_jax() + + + +**CLI 支持** +在 cli.py 中扩展 --task flow,从 Python 脚本加载 get_flow_model() 返回 FlowModel。 \ No newline at end of file From b3dc633a10b0d6dbfb3b65fafa9a0fdf26450fe6 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 22:53:01 +0800 Subject: [PATCH 05/14] =?UTF-8?q?feat(codegen):=20=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=B5=81=E7=A8=8B=E5=B1=82=E6=A8=A1=E5=9E=8B=E7=B1=BB?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/__init__.py | 15 ++++++++- codegen/model.py | 79 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 93 insertions(+), 1 deletion(-) diff --git a/codegen/__init__.py b/codegen/__init__.py index 810929d..97afed7 100644 --- a/codegen/__init__.py +++ b/codegen/__init__.py @@ -1,6 +1,10 @@ """codegen 包 — 对外 re-export 所有公开符号,保证 from sympy_codegen import ... 仍然可用。""" -from codegen.model import ExternalOperator, ExternalCall, MathModel, external_call +from codegen.model import ( + ExternalOperator, ExternalCall, MathModel, external_call, + Assign, BufferZero, BufferCopy, BufferAccum, Call, If, For, + Buffer, FlowModel, +) from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter from codegen.printer import FEACodePrinter, FEAFortranPrinter from codegen.compiler import FEACompiler @@ -11,6 +15,15 @@ "ExternalCall", "MathModel", "external_call", + "Assign", + "BufferZero", + "BufferCopy", + "BufferAccum", + "Call", + "If", + "For", + "Buffer", + "FlowModel", "LoweredChunk", "LoweredModel", "CachedPrinter", diff --git a/codegen/model.py b/codegen/model.py index 234eff6..3f6c451 100644 --- a/codegen/model.py +++ b/codegen/model.py @@ -47,3 +47,82 @@ def external_call(model, op_name, input_exprs, n_outputs=None, prefix=None): ) return output_symbols + + +# ═══════════════════════════════════════════════════════════ +# Flow Layer — 流程模型数据类 +# ═══════════════════════════════════════════════════════════ + +# ─── 标量操作 ─────────────────────────────────────────── + +class Assign: + """标量赋值:var = expr""" + def __init__(self, target, expr): + self.target = target # str | sp.Symbol + self.expr = expr # sp.Expr | int | float + +# ─── 缓冲区操作 ───────────────────────────────────────── + +class BufferZero: + """缓冲区清零:buf[:] = 0.0""" + def __init__(self, target): + self.target = target # str (必须在 FlowModel.local_buffers 中声明) + +class BufferCopy: + """缓冲区拷贝:target[:] = source[:]""" + def __init__(self, target, source): + self.target = target # str + self.source = source # str + +class BufferAccum: + """缓冲区累加:target[i] += source[i]""" + def __init__(self, target, source): + self.target = target # str + self.source = source # str + +# ─── 控制流 ────────────────────────────────────────────── + +class Call: + """调用子模型:output_vars[i] = submodel(input_exprs)[i]""" + def __init__(self, model_name, input_exprs, output_vars): + self.model_name = model_name # str, key in FlowModel.submodels + self.input_exprs = list(input_exprs) # list[sp.Expr | str] + self.output_vars = list(output_vars) # list[str] + +class If: + """条件分支""" + def __init__(self, cond, then_body, else_body=None): + self.cond = cond # sp.Expr (Relational) + self.then_body = list(then_body) # list[Statement] + self.else_body = list(else_body or []) # list[Statement] + +class For: + """循环""" + def __init__(self, index, start, end, body, unroll=False): + self.index = index # sp.Symbol + self.start = start # int | sp.Expr + self.end = end # int | sp.Expr + self.body = list(body) # list[Statement] + self.unroll = unroll # True: 生成时展开; False: 生成循环语句 + +# ─── 缓冲区声明 ───────────────────────────────────────── + +class Buffer: + """局部缓冲区声明""" + def __init__(self, name, size, dtype="double"): + self.name = name # str + self.size = size # int (标量元素数,如 24*24=576) + self.dtype = dtype # "double" | "int" | ... + +# ─── 流程模型 ─────────────────────────────────────────── + +class FlowModel: + """流程模型:命令式主程序""" + def __init__(self, name, inputs, outputs, body, + local_buffers=None, submodels=None): + self.name = name # str + self.inputs = inputs # list[sp.Symbol] + self.outputs = list(outputs) # list[str | sp.Symbol] + self.body = list(body) # list[Statement] + self.local_buffers = local_buffers or [] # list[Buffer] + self.submodels = submodels or {} # dict[str, MathModel] From deabc9c5041c093034d37012a7e7d0cdc4f24243 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 22:59:55 +0800 Subject: [PATCH 06/14] =?UTF-8?q?feat(codegen):=20=E6=B7=BB=E5=8A=A0=20Flo?= =?UTF-8?q?wModel=20=E7=BC=96=E8=AF=91=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/compiler.py | 517 +++++++++++++++++++++++++++++++++++++++++++- 1 file changed, 516 insertions(+), 1 deletion(-) diff --git a/codegen/compiler.py b/codegen/compiler.py index 39be9fd..e5b5c2e 100644 --- a/codegen/compiler.py +++ b/codegen/compiler.py @@ -4,7 +4,10 @@ from sympy.core.relational import Relational from sympy.printing.numpy import JaxPrinter -from codegen.model import MathModel +from codegen.model import ( + MathModel, FlowModel, Assign, BufferZero, BufferCopy, BufferAccum, + Call, If, For, Buffer, +) from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter from codegen.printer import FEACodePrinter, FEAFortranPrinter from definitions.abc import Element @@ -592,3 +595,515 @@ def compile_element(element: Element, target: str, chunk_size=None, cse_strategy model = element.get_stiffness_model() return {model.name: FEACompiler.compile(model, target, chunk_size=chunk_size, cse_strategy=cse_strategy)} + + # ========================================================================= + # FlowModel 编译 — 命令式主流程代码生成 + # ========================================================================= + @staticmethod + def compile_flow(flow: FlowModel, target: str, chunk_size=None, cse_strategy="auto"): + """ + 编译 FlowModel,生成主流程函数 + 所有子模型函数的完整源码。 + + Args: + flow: FlowModel 实例 + target: 目标平台 ('cpp', 'cuda', 'fortran', 'jax') + chunk_size: 子模型 CSE chunk size (可选) + cse_strategy: CSE 策略 + + Returns: + str: 完整源码(包含子模型函数 + 主流程函数) + """ + target = target.lower() + if target in ('cpp', 'c++', 'cuda'): + is_cuda = (target == 'cuda') + return FEACompiler._flow_to_source(flow, is_cuda=is_cuda, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'fortran': + return FEACompiler._flow_to_fortran(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) + else: + raise ValueError(f"FlowModel does not support target '{target}' yet") + + # ------------------------------------------------------------------------- + # C++/CUDA Flow 代码生成 + # ------------------------------------------------------------------------- + @staticmethod + def _flow_to_source(flow: FlowModel, is_cuda=False, chunk_size=None, cse_strategy="auto"): + """生成 FlowModel 的 C++/CUDA 源码""" + + # 1. 编译所有子模型 → compute_xxx 函数源码(剥离重复的宏定义) + macros_str = FEACompiler._cpp_cuda_compat_macros() + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = FEACompiler.compile(sub_model, "cuda" if is_cuda else "cpp", + chunk_size=chunk_size, cse_strategy=cse_strategy) + # 剥离子模型源码中的宏定义(避免重复) + sub_src = sub_src.replace(macros_str, "").strip() + sub_sources.append(sub_src) + + # 2. 构建缓冲区查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + + # 3. 收集标量变量(Call 输出中不在 buffer_map 的) + # 需要先遍历 body 收集 + scalar_vars = set() + FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + + # 4. 生成主流程函数 + printer = CachedPrinter(FEACodePrinter()) + + # --- 函数注释 --- + comment_lines = ["/**"] + comment_lines.append(f" * @brief Flow kernel: {flow.name}") + comment_lines.append(" * ") + comment_lines.append(" * @param in Input array (const double*). Layout:") + for i, sym in enumerate(flow.inputs): + comment_lines.append(f" * - in[{i}]: {sym}") + comment_lines.append(" * ") + comment_lines.append(" * @param out Output array (double*). Layout:") + for i, name in enumerate(flow.outputs): + comment_lines.append(f" * - out[{i}]: {name}") + comment_lines.append(" */") + comment_block = "\n".join(comment_lines) + + # --- 函数体 --- + body_lines = [] + + # 解包输入 + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + body_lines.append(f" double {s} = in[{i}];") + body_lines.append("") + + # 声明缓冲区 + for buf in flow.local_buffers: + dtype = "double" if buf.dtype == "double" else buf.dtype + body_lines.append(f" {dtype} {buf.name}[{buf.size}];") + body_lines.append("") + + # 声明标量变量 + if scalar_vars: + for var in sorted(scalar_vars): + body_lines.append(f" double {var};") + body_lines.append("") + + # 生成 body + FEACompiler._emit_body(flow.body, buffer_map, printer, body_lines, indent=1) + + # 输出映射 + body_lines.append("") + body_lines.append(" // --- Output ---") + for i, out_name in enumerate(flow.outputs): + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + for j in range(buf.size): + body_lines.append(f" out[{i * buf.size + j}] = {s}[{j}];") + else: + body_lines.append(f" out[{i}] = {s};") + + body = "\n".join(body_lines) + + # --- 函数签名 --- + prefix_macros = FEACompiler._cpp_cuda_compat_macros() + "\n" + if is_cuda: + func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" + else: + func_type = "FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{flow.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + + main_func = f"{comment_block}\n{signature} {{\n{body}\n}}" + + # 组装:宏定义(1次) + 子模型函数 + 主流程函数 + parts = [prefix_macros] + sub_sources + [main_func] + return "\n\n".join(parts) + + # ------------------------------------------------------------------------- + # Fortran Flow 代码生成 + # ------------------------------------------------------------------------- + @staticmethod + def _flow_to_fortran(flow: FlowModel, chunk_size=None, cse_strategy="auto"): + """生成 FlowModel 的 Fortran 源码""" + + # 1. 编译所有子模型 + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = FEACompiler.compile(sub_model, "fortran", + chunk_size=chunk_size, cse_strategy=cse_strategy) + sub_sources.append(sub_src) + + # 2. 构建查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + scalar_vars = set() + FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + # 收集 For 循环的 index 变量(Fortran 需声明为 integer) + for_indices = set() + FEACompiler._collect_for_indices(flow.body, for_indices) + + # 3. 生成主流程 + printer = CachedPrinter(FEAFortranPrinter()) + + lines = [ + "! Generated by sympy_codegen.py. Do not edit.", + "!", + f"! Flow kernel: compute_{flow.name}", + "!", + f"subroutine compute_{flow.name}(in_vec, out_vec)", + " implicit none", + f" double precision, intent(in) :: in_vec({len(flow.inputs)})", + f" double precision, intent(out) :: out_vec({len(flow.outputs)})", + ] + + # 解包输入 + input_vars = [] + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + input_vars.append(s) + + if input_vars: + lines.append(" ! --- Unpack inputs ---") + lines.append(f" double precision :: {', '.join(input_vars)}") + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_vec({i + 1})") + lines.append("") + + # 声明缓冲区 + if flow.local_buffers: + lines.append(" ! --- Local Buffers ---") + for buf in flow.local_buffers: + lines.append(f" double precision :: {buf.name}({buf.size})") + lines.append("") + + # 声明标量变量 + if scalar_vars: + lines.append(" ! --- Local Scalars ---") + lines.append(f" double precision :: {', '.join(sorted(scalar_vars))}") + lines.append("") + + # 声明 For 循环 index 变量 + if for_indices: + lines.append(" ! --- Loop Indices ---") + lines.append(f" integer :: {', '.join(sorted(for_indices))}") + lines.append("") + + # 生成 body + FEACompiler._emit_body_fortran(flow.body, buffer_map, printer, lines, indent=1) + + # 输出映射 + lines.append("") + lines.append(" ! --- Output ---") + offset = 0 + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + for j in range(buf.size): + lines.append(f" out_vec({offset + j + 1}) = {s}({j + 1})") + offset += buf.size + else: + lines.append(f" out_vec({offset + 1}) = {s}") + offset += 1 + + lines.append(f"end subroutine compute_{flow.name}") + + # 组装 + parts = sub_sources + ["\n".join(lines)] + return "\n\n".join(parts) + + # ========================================================================= + # 辅助方法:标量变量收集 + # ========================================================================= + @staticmethod + def _collect_scalar_vars(body, buffer_map, scalar_vars): + """递归收集 body 中 Call 输出的标量变量名""" + for stmt in body: + if isinstance(stmt, Call): + for var_name in stmt.output_vars: + if var_name not in buffer_map: + scalar_vars.add(var_name) + elif isinstance(stmt, If): + FEACompiler._collect_scalar_vars(stmt.then_body, buffer_map, scalar_vars) + FEACompiler._collect_scalar_vars(stmt.else_body, buffer_map, scalar_vars) + elif isinstance(stmt, For): + FEACompiler._collect_scalar_vars(stmt.body, buffer_map, scalar_vars) + + @staticmethod + def _collect_for_indices(body, indices): + """递归收集 For 语句的 index 变量名""" + for stmt in body: + if isinstance(stmt, For): + indices.add(str(stmt.index)) + FEACompiler._collect_for_indices(stmt.body, indices) + elif isinstance(stmt, If): + FEACompiler._collect_for_indices(stmt.then_body, indices) + FEACompiler._collect_for_indices(stmt.else_body, indices) + + # ========================================================================= + # 辅助方法:C++/CUDA 语句生成 + # ========================================================================= + @staticmethod + def _emit_body(body, buffer_map, printer, lines, indent=1): + """递归生成 C++/CUDA 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = FEACompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str};") + + elif isinstance(stmt, BufferZero): + buf = buffer_map[stmt.target] + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = 0.0;") + + elif isinstance(stmt, BufferCopy): + buf = buffer_map[stmt.target] + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {stmt.source}[_i];") + + elif isinstance(stmt, BufferAccum): + buf = buffer_map[stmt.target] + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] += {stmt.source}[_i];") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + # 计算子模型的总输入/输出数量 + # input_exprs 中可能有标量引用或 SymPy 表达式 + # 当 input_exprs 元素是 str 且在 buffer_map 中时,需要逐元素填充 + lines.append(f"{pad}// Call: {model_name}") + # 计算总输入元素数 + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + lines.append(f"{pad}{{") + lines.append(f"{pad} double _call_in[{total_in}];") + # 填充输入 + in_idx = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + buf = buffer_map[e] + for j in range(buf.size): + lines.append(f"{pad} _call_in[{in_idx}] = {e}[{j}];") + in_idx += 1 + else: + lines.append(f"{pad} _call_in[{in_idx}] = {FEACompiler._print_expr(e, printer)};") + in_idx += 1 + + # 输出:标量用临时变量,缓冲区直接传 + scalar_outs = [] + buf_outs = [] + for var_name in stmt.output_vars: + if var_name in buffer_map: + buf_outs.append(var_name) + else: + scalar_outs.append(var_name) + + if scalar_outs: + lines.append(f"{pad} double _call_out_scalar[{len(scalar_outs)}];") + + # 构建连续输出数组 + if buf_outs: + # 缓冲区直接传递,标量先写进临时数组再取回 + # 需要知道子模型的 output layout + # 简化方案:总是构造完整 _call_out 数组,然后拷贝回去 + sub_model = None # 不引用 flow.submodels 以保持纯静态 + # 用通用方案:构造 _call_out 数组 + total_out = len(stmt.output_vars) + lines.append(f"{pad} double _call_out[{total_out}];") + lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") + + # 把输出分发给各个变量 + for i, var_name in enumerate(stmt.output_vars): + if var_name in buffer_map: + buf = buffer_map[var_name] + for j in range(buf.size): + lines.append(f"{pad} {var_name}[{j}] = _call_out[{i * buf.size + j}];") + else: + lines.append(f"{pad} {var_name} = _call_out[{i}];") + else: + # 全标量输出 + lines.append(f"{pad} double _call_out[{len(scalar_outs)}];") + lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") + for i, var_name in enumerate(scalar_outs): + lines.append(f"{pad} {var_name} = _call_out[{i}];") + + lines.append(f"{pad}}}") + + elif isinstance(stmt, If): + cond_str = FEACompiler._print_expr(stmt.cond, printer) + lines.append(f"{pad}if ({cond_str}) {{") + FEACompiler._emit_body(stmt.then_body, buffer_map, printer, lines, indent + 1) + if stmt.else_body: + lines.append(f"{pad}}} else {{") + FEACompiler._emit_body(stmt.else_body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}}}") + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + # 展开循环 + for i in range(int(start), int(end)): + lines.append(f"{pad}// Unrolled iteration {idx} = {i}") + # 替换 body 中引用 index 的表达式 + sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) + FEACompiler._emit_body(sub_body, buffer_map, printer, lines, indent) + else: + lines.append(f"{pad}for (int {idx} = {start}; {idx} < {end}; {idx}++) {{") + FEACompiler._emit_body(stmt.body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}}}") + + # ========================================================================= + # 辅助方法:Fortran 语句生成 + # ========================================================================= + @staticmethod + def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): + """递归生成 Fortran 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = FEACompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str}") + + elif isinstance(stmt, BufferZero): + buf = buffer_map[stmt.target] + lines.append(f"{pad}{stmt.target}(:) = 0.0d0") + + elif isinstance(stmt, BufferCopy): + lines.append(f"{pad}{stmt.target}(:) = {stmt.source}(:)") + + elif isinstance(stmt, BufferAccum): + buf = buffer_map[stmt.target] + lines.append(f"{pad}{stmt.target}(:) = {stmt.target}(:) + {stmt.source}(:)") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + + # 计算总输入元素数(buffer 展开为逐元素) + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + lines.append(f"{pad}! Call: {model_name}") + lines.append(f"{pad}block") + lines.append(f"{pad} double precision :: _call_in({total_in})") + total_out = len(stmt.output_vars) + lines.append(f"{pad} double precision :: _call_out({total_out})") + + in_idx = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + buf = buffer_map[e] + for j in range(buf.size): + lines.append(f"{pad} _call_in({in_idx + 1}) = {e}({j + 1})") + in_idx += 1 + else: + lines.append(f"{pad} _call_in({in_idx + 1}) = {FEACompiler._print_expr(e, printer)}") + in_idx += 1 + + lines.append(f"{pad} call compute_{model_name}(_call_in, _call_out)") + + # 分发输出 + for i, var_name in enumerate(stmt.output_vars): + if var_name in buffer_map: + buf = buffer_map[var_name] + for j in range(buf.size): + lines.append(f"{pad} {var_name}({j + 1}) = _call_out({i * buf.size + j + 1})") + else: + lines.append(f"{pad} {var_name} = _call_out({i + 1})") + + lines.append(f"{pad}end block") + + elif isinstance(stmt, If): + cond_str = FEACompiler._print_expr(stmt.cond, printer) + lines.append(f"{pad}if ({cond_str}) then") + FEACompiler._emit_body_fortran(stmt.then_body, buffer_map, printer, lines, indent + 1) + if stmt.else_body: + lines.append(f"{pad}else") + FEACompiler._emit_body_fortran(stmt.else_body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}end if") + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + for i in range(int(start), int(end)): + lines.append(f"{pad}! Unrolled iteration {idx} = {i}") + sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) + FEACompiler._emit_body_fortran(sub_body, buffer_map, printer, lines, indent) + else: + lines.append(f"{pad}do {idx} = {start}, {end} - 1") + FEACompiler._emit_body_fortran(stmt.body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}end do") + + # ========================================================================= + # 辅助方法:表达式打印 & 索引替换 + # ========================================================================= + @staticmethod + def _print_expr(expr, printer): + """将 SymPy 表达式或原始值打印为字符串""" + if isinstance(expr, (int, float)): + return str(expr) + if isinstance(expr, str): + return expr + if isinstance(expr, sp.Basic): + return printer.doprint(expr) + return str(expr) + + @staticmethod + def _substitute_index(body, index_sym, value): + """将 body 中所有引用 index_sym 的表达式替换为具体值,返回新的 body 列表""" + import copy + new_body = [] + for stmt in body: + if isinstance(stmt, Assign): + new_expr = stmt.expr + new_target = stmt.target + if isinstance(new_expr, sp.Basic): + new_expr = new_expr.subs(index_sym, value) + if isinstance(new_target, sp.Basic): + new_target = new_target.subs(index_sym, value) + new_body.append(Assign(new_target, new_expr)) + + elif isinstance(stmt, Call): + new_input_exprs = [] + for e in stmt.input_exprs: + if isinstance(e, sp.Basic): + new_input_exprs.append(e.subs(index_sym, value)) + else: + new_input_exprs.append(e) + new_body.append(Call(stmt.model_name, new_input_exprs, stmt.output_vars)) + + elif isinstance(stmt, If): + new_cond = stmt.cond + if isinstance(new_cond, sp.Basic): + new_cond = new_cond.subs(index_sym, value) + new_then = FEACompiler._substitute_index(stmt.then_body, index_sym, value) + new_else = FEACompiler._substitute_index(stmt.else_body, index_sym, value) + new_body.append(If(new_cond, new_then, new_else)) + + elif isinstance(stmt, For): + # 不替换嵌套 For 的 index(不同循环变量),但替换 body 内引用外层 index 的部分 + new_start = stmt.start.subs(index_sym, value) if isinstance(stmt.start, sp.Basic) else stmt.start + new_end = stmt.end.subs(index_sym, value) if isinstance(stmt.end, sp.Basic) else stmt.end + new_body_inner = FEACompiler._substitute_index(stmt.body, index_sym, value) + new_body.append(For(stmt.index, new_start, new_end, new_body_inner, stmt.unroll)) + + else: + # BufferZero, BufferCopy, BufferAccum — 不含表达式,直接拷贝 + new_body.append(stmt) + + return new_body From 87b590061095b7e748b3b7d75cedcaec0098e141 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:06:31 +0800 Subject: [PATCH 07/14] =?UTF-8?q?feat(compiler):=20=E6=B7=BB=E5=8A=A0=20JA?= =?UTF-8?q?X=20=E7=9B=AE=E6=A0=87=E4=BB=A3=E7=A0=81=E7=94=9F=E6=88=90?= =?UTF-8?q?=E6=94=AF=E6=8C=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/compiler.py | 357 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 357 insertions(+) diff --git a/codegen/compiler.py b/codegen/compiler.py index e5b5c2e..e0b0dda 100644 --- a/codegen/compiler.py +++ b/codegen/compiler.py @@ -620,6 +620,8 @@ def compile_flow(flow: FlowModel, target: str, chunk_size=None, cse_strategy="au chunk_size=chunk_size, cse_strategy=cse_strategy) elif target == 'fortran': return FEACompiler._flow_to_fortran(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'jax': + return FEACompiler._flow_to_jax(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) else: raise ValueError(f"FlowModel does not support target '{target}' yet") @@ -1107,3 +1109,358 @@ def _substitute_index(body, index_sym, value): new_body.append(stmt) return new_body + + # ========================================================================= + # JAX Flow 代码生成 + # ========================================================================= + @staticmethod + def _flow_to_jax(flow: FlowModel, chunk_size=None, cse_strategy="auto"): + """ + 生成 FlowModel 的 JAX 源码。 + + JAX 是纯函数式的,生成策略: + - 缓冲区 → jnp.zeros / jnp.array + - BufferAccum → buf = buf.at[i].add(src[i]) 或 buf = buf + src + - For(unroll=False) → jax.lax.fori_loop + - For(unroll=True) → Python 展开循环 + - If → jax.lax.cond (then/else 必须返回相同结构的 tuple) + - Call → 调用子模型函数 + """ + + # 1. 编译所有子模型 + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = FEACompiler.compile(sub_model, "jax", + chunk_size=chunk_size, cse_strategy=cse_strategy) + sub_sources.append(sub_src) + + # 2. 构建查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + scalar_vars = set() + FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + for_indices = set() + FEACompiler._collect_for_indices(flow.body, for_indices) + + # 3. 生成主流程 + from sympy.printing.numpy import JaxPrinter + printer = CachedPrinter(JaxPrinter()) + + lines = [ + '"""Generated by sympy_codegen.py. Do not edit."""', + "import jax", + "import jax.numpy as jnp", + "", + "", + f"def compute_{flow.name}(in_flat):", + f' """', + f' Flow kernel: {flow.name}', + f' ', + f' Args:', + f' in_flat: Flattened input array, size {len(flow.inputs)}', + f' ', + f' Returns:', + f' Flattened output array', + f' """', + ] + + # 解包输入 + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_flat[{i}]") + lines.append("") + + # 初始化缓冲区 + for buf in flow.local_buffers: + if buf.dtype == "double": + lines.append(f" {buf.name} = jnp.zeros({buf.size})") + else: + lines.append(f" {buf.name} = jnp.zeros({buf.size}, dtype=jnp.{buf.dtype})") + lines.append("") + + # 声明标量变量初始值(用于 JAX 的函数式风格) + # JAX 中标量不需要预声明,在赋值时绑定即可 + + # 生成 body + FEACompiler._emit_body_jax(flow.body, buffer_map, printer, lines, indent=1, for_indices=for_indices) + + # 输出映射 + lines.append("") + lines.append(" # --- Output ---") + out_parts = [] + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + out_parts.append(s) + else: + out_parts.append(s) + + if len(out_parts) == 1: + lines.append(f" return {out_parts[0]}") + else: + lines.append(f" return jnp.concatenate([{', '.join(out_parts)}])") + + src = "\n".join(sub_sources) + "\n\n\n" + "\n".join(lines) + src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") + return src + + # ------------------------------------------------------------------------- + # JAX 语句生成 + # ------------------------------------------------------------------------- + @staticmethod + def _emit_body_jax(body, buffer_map, printer, lines, indent=1, for_indices=None): + """递归生成 JAX 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = FEACompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str}") + + elif isinstance(stmt, BufferZero): + buf = buffer_map[stmt.target] + lines.append(f"{pad}{stmt.target} = jnp.zeros({buf.size})") + + elif isinstance(stmt, BufferCopy): + lines.append(f"{pad}{stmt.target} = {stmt.source}.copy()") + + elif isinstance(stmt, BufferAccum): + buf_target = buffer_map[stmt.target] + buf_source = buffer_map[stmt.source] + if buf_target.size == buf_source.size: + # 同尺寸:向量加法 + lines.append(f"{pad}{stmt.target} = {stmt.target} + {stmt.source}") + else: + # 不同尺寸:逐元素 at[].add()(不太常见,但保留安全路径) + lines.append(f"{pad}{stmt.target} = {stmt.target}.at[:{buf_source.size}].add({stmt.source})") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + + # 计算总输入元素数 + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + # 构建输入 + in_parts = [] + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + in_parts.append(e) + else: + in_parts.append(FEACompiler._print_expr(e, printer)) + + lines.append(f"{pad}# Call: {model_name}") + if len(in_parts) == 1: + # 单个输入,可能是标量或数组 + lines.append(f"{pad}_call_in = jnp.array([{in_parts[0]}]) if not isinstance({in_parts[0]}, jnp.ndarray) else {in_parts[0]}.reshape(-1)") + else: + # 多个输入,拼接为 1D 数组 + items = ", ".join(in_parts) + lines.append(f"{pad}_call_in = jnp.concatenate([jnp.atleast_1d(jnp.asarray(x)) for x in [{items}]])") + + lines.append(f"{pad}_call_out = compute_{model_name}(_call_in)") + + # 分发输出 + offset = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + buf = buffer_map[var_name] + lines.append(f"{pad}{var_name} = _call_out[{offset}:{offset + buf.size}]") + offset += buf.size + else: + lines.append(f"{pad}{var_name} = _call_out[{offset}]") + offset += 1 + + elif isinstance(stmt, If): + # JAX: jax.lax.cond(pred, true_fun, false_fun, operand) + # 但生成 jax.lax.cond 的完整函数定义太复杂, + # 如果 body 中只有简单语句,用 jnp.where 更实用 + # 否则用 jax.lax.cond + + # 策略:检测 then_body 和 else_body 的复杂度 + # 如果两者都是纯赋值/BufferAccum,用 jnp.where 系列 + # 如果有 Call/For/嵌套 If,用 jax.lax.cond + + is_simple = FEACompiler._is_simple_jax_if(stmt) + + if is_simple: + # 简单 If:逐语句生成 jnp.where 版本 + FEACompiler._emit_simple_if_jax(stmt, buffer_map, printer, lines, indent) + else: + # 复杂 If:用 jax.lax.cond + FEACompiler._emit_cond_if_jax(stmt, buffer_map, printer, lines, indent) + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + # 展开循环 + for i in range(int(start), int(end)): + lines.append(f"{pad}# Unrolled iteration {idx} = {i}") + sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) + FEACompiler._emit_body_jax(sub_body, buffer_map, printer, lines, indent, for_indices) + else: + # jax.lax.fori_loop + # 需要把循环体封装为一个函数 + # 携带状态 = 所有缓冲区 + 循环体内修改的标量 + carried = FEACompiler._collect_carried_vars(stmt.body, buffer_map) + + if carried: + carry_names = sorted(carried) + carry_tuple = ", ".join(carry_names) + + # 生成循环体函数 + lines.append(f"{pad}def _for_body_{idx}({idx}, _carry):") + for i, name in enumerate(carry_names): + lines.append(f"{pad} {name} = _carry[{i}]") + + # 生成循环体 + FEACompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent + 1, for_indices) + + # 返回 carry + ret_parts = ", ".join(carry_names) + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # 调用 fori_loop + init_parts = ", ".join(carry_names) + if len(carry_names) == 1: + # 单元素 carry: fori_loop 返回 (val,), 需要 [0] 解包 + lines.append(f"{pad}{carry_tuple}, = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") + else: + lines.append(f"{pad}{carry_tuple} = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") + else: + # 无携带状态,循环体无副作用,只需调用一次 + lines.append(f"{pad}# For loop with no carried state — body executed once") + FEACompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent, for_indices) + + # ------------------------------------------------------------------------- + # JAX If 辅助方法 + # ------------------------------------------------------------------------- + @staticmethod + def _is_simple_jax_if(if_stmt): + """判断 If 语句是否足够简单,可以用 jnp.where 实现""" + def _is_simple_body(body): + for s in body: + if isinstance(s, (Call, For, If)): + return False + if isinstance(s, BufferAccum): + return False # BufferAccum 需要 += 语义 + return True + + return _is_simple_body(if_stmt.then_body) and _is_simple_body(if_stmt.else_body) + + @staticmethod + def _emit_simple_if_jax(if_stmt, buffer_map, printer, lines, indent): + """用 jnp.where 生成简单 If""" + pad = " " * indent + cond_str = FEACompiler._print_expr(if_stmt.cond, printer) + + # 先生成 then_body 的赋值 + for s in if_stmt.then_body: + if isinstance(s, Assign): + target = str(s.target) + then_val = FEACompiler._print_expr(s.expr, printer) + # 从 else_body 找同名赋值,或用原值 + else_val = target # 默认:不变 + for es in if_stmt.else_body: + if isinstance(es, Assign) and str(es.target) == target: + else_val = FEACompiler._print_expr(es.expr, printer) + break + if target in buffer_map: + lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") + else: + lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") + + elif isinstance(s, BufferZero): + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.zeros({buffer_map[s.target].size}), {s.target})") + + elif isinstance(s, BufferCopy): + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, {s.source}.copy(), {s.target})") + + # else_body 中独有的赋值 + then_targets = {str(s.target) for s in if_stmt.then_body if isinstance(s, (Assign, BufferZero, BufferCopy))} + for s in if_stmt.else_body: + if isinstance(s, Assign) and str(s.target) not in then_targets: + target = str(s.target) + else_val = FEACompiler._print_expr(s.expr, printer) + lines.append(f"{pad}{target} = jnp.where({cond_str}, {target}, {else_val})") + + @staticmethod + def _emit_cond_if_jax(if_stmt, buffer_map, printer, lines, indent): + """用 jax.lax.cond 生成复杂 If""" + pad = " " * indent + cond_str = FEACompiler._print_expr(if_stmt.cond, printer) + + # 收集 then/else 修改的变量 + then_carried = FEACompiler._collect_carried_vars(if_stmt.then_body, buffer_map) + else_carried = FEACompiler._collect_carried_vars(if_stmt.else_body, buffer_map) + carried = sorted(then_carried | else_carried) + + if not carried: + # 无副作用,直接生成 then_body(条件满足时执行) + # 但 JAX 的 cond 需要两边都有返回值 + # 简化:直接生成 then_body + FEACompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent) + return + + carry_tuple = ", ".join(carried) + + # then 函数 + lines.append(f"{pad}def _if_true(_carry):") + for i, name in enumerate(carried): + lines.append(f"{pad} {name} = _carry[{i}]") + FEACompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent + 1) + ret_parts = ", ".join(carried) + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # else 函数 + lines.append(f"{pad}def _if_false(_carry):") + for i, name in enumerate(carried): + lines.append(f"{pad} {name} = _carry[{i}]") + if if_stmt.else_body: + FEACompiler._emit_body_jax(if_stmt.else_body, buffer_map, printer, lines, indent + 1) + # else 必须和 then 返回相同结构 + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # 调用 jax.lax.cond + init_parts = ", ".join(carried) + if len(carried) == 1: + lines.append(f"{pad}{carry_tuple}, = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") + else: + lines.append(f"{pad}{carry_tuple} = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") + + # ------------------------------------------------------------------------- + # JAX 携带变量收集 + # ------------------------------------------------------------------------- + @staticmethod + def _collect_carried_vars(body, buffer_map): + """收集 body 中被修改的变量名(需要作为 fori_loop/cond 的 carry)""" + carried = set() + for stmt in body: + if isinstance(stmt, Assign): + carried.add(str(stmt.target)) + elif isinstance(stmt, BufferZero): + carried.add(stmt.target) + elif isinstance(stmt, BufferCopy): + carried.add(stmt.target) + elif isinstance(stmt, BufferAccum): + carried.add(stmt.target) + elif isinstance(stmt, Call): + for var_name in stmt.output_vars: + carried.add(var_name) + elif isinstance(stmt, If): + then_c = FEACompiler._collect_carried_vars(stmt.then_body, buffer_map) + else_c = FEACompiler._collect_carried_vars(stmt.else_body, buffer_map) + carried |= then_c | else_c + elif isinstance(stmt, For): + carried |= FEACompiler._collect_carried_vars(stmt.body, buffer_map) + return carried From 6d4d10d8f8cc2989c5442da0f8630ca0487c2742 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:06:52 +0800 Subject: [PATCH 08/14] =?UTF-8?q?docs:=20=E7=A7=BB=E9=99=A4=20IR=20?= =?UTF-8?q?=E5=B1=82=E5=AE=9E=E7=8E=B0=E8=AE=A1=E5=88=92=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- plan.md | 57 --------------------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 plan.md diff --git a/plan.md b/plan.md deleted file mode 100644 index a9deac4..0000000 --- a/plan.md +++ /dev/null @@ -1,57 +0,0 @@ -准备在这个代码生成器中增加IR层,实现mathmodel之间的调度流程 - -分步实现计划 -按照最小化修改原则,分三步落地: - -```mermaid -graph LR - S1["Step 1
数据模型定义"] --> S2["Step 2
C++ 后端编译"] - S2 --> S3["Step 3
Fortran + JAX 后端"] - - style S1 fill:#e1f5fe - style S2 fill:#fff3e0 - style S3 fill:#e8f5e9 - -``` - -**Step 1: 数据模型定义(仅 codegen/model.py)** -新增 FlowModel, Buffer, Assign, BufferZero, BufferCopy, BufferAccum, Call, If, For。纯数据类,无逻辑。更新 __init__.py 导出。 - -改动文件: - -- codegen/model.py — 新增类定义 - -- codegen/__init__.py — 新增导出 - -**Step 2: C++ 后端编译(codegen/compiler.py + codegen/printer.py)** - 在 FEACompiler 中新增 compile_flow() 方法,先生成子模型的 compute_xxx 函数,再生成主流程函数。 - -核心逻辑: - -编译所有 submodels → 得到 compute_xxx 函数源码 - -1. 遍历 body 生成主流程函数体 -2. 处理 Call → 发出函数调用语句,标量输出用 double var;,缓冲区输出用 double var[N]; -3. 处理 For → 发出 for (int idx = start; idx < end; idx++),unroll=True 时展开 -4. 处理 If → 发出 if (cond) { ... } else { ... } -5. 处理 BufferZero/Copy/Accum → 发出对应语句 - - -改动文件: - -codegen/compiler.py — 新增 _flow_to_source(), compile_flow() -codegen/printer.py — 可能需要辅助方法 - -**Step 3: Fortran + JAX 后端** -Fortran 的 For → do 循环,If → if/else,Call → call sub()。 - -JAX 最特殊:For → jax.lax.fori_loop / jax.lax.scan,If → jax.lax.cond,BufferAccum → buffer.at[i].add()。JAX 后端可以后置,先确保 C++/Fortran 可用。 - -改动文件: - -codegen/compiler.py — 新增 _flow_to_fortran(), _flow_to_jax() - - - -**CLI 支持** -在 cli.py 中扩展 --task flow,从 Python 脚本加载 get_flow_model() 返回 FlowModel。 \ No newline at end of file From b608662e1ef60703d29068bca504d1a777a2bce3 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:11:55 +0800 Subject: [PATCH 09/14] feat(cli): add flow task support --- codegen/cli.py | 46 ++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/codegen/cli.py b/codegen/cli.py index 02564d3..e256c39 100644 --- a/codegen/cli.py +++ b/codegen/cli.py @@ -6,6 +6,7 @@ from pathlib import Path from codegen.compiler import FEACompiler +from codegen.model import FlowModel from codegen.loader import load_element, load_material @@ -28,8 +29,8 @@ def main(): parser.add_argument( "--task", required=True, - choices=["constitutive", "stiffness", "mass", "custom"], - help="生成任务: 'constitutive' (材料D矩阵), 'stiffness' (单元Ke矩阵), 'mass' (质量矩阵), 或 'custom' (自定义数学模型)", + choices=["constitutive", "stiffness", "mass", "custom", "flow"], + help="生成任务: 'constitutive' (材料D矩阵), 'stiffness' (单元Ke矩阵), 'mass' (质量矩阵), 'custom' (自定义数学模型), 或 'flow' (流程模型)", ) parser.add_argument( "--element", "-e", @@ -132,6 +133,47 @@ def main(): else: parser.error(f"get_model() must return a MathModel or a list of MathModels. Got: {type(models)}") + elif args.task == "flow": + if not args.script: + parser.error("--script is required for --task=flow") + script_path = Path(args.script) + if not script_path.exists(): + parser.error(f"Script file not found: {script_path}") + + # Dynamically load the script + spec = importlib.util.spec_from_file_location("flow_script", str(script_path)) + flow_mod = importlib.util.module_from_spec(spec) + sys.modules["flow_script"] = flow_mod + spec.loader.exec_module(flow_mod) + + if not hasattr(flow_mod, "get_flow_model"): + parser.error(f"Script {script_path} must define a 'get_flow_model()' function.") + + flow = flow_mod.get_flow_model() + if not isinstance(flow, FlowModel): + parser.error(f"get_flow_model() must return a FlowModel. Got: {type(flow)}") + + # FlowModel 使用单独的编译路径 + target = args.target + if target == "all": + targets = ["cpp", "cuda", "fortran", "jax"] + else: + targets = [target] + + for t in targets: + code = FEACompiler.compile_flow( + flow, t, + chunk_size=args.chunk_size, + cse_strategy=args.cse_strategy, + ) + out_path = Path(args.output or ".") / _default_output(flow.name, t) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + f.write(code) + print(f"Generated: {out_path}") + + return + # ---------------- Compile Models ---------------- base_test_dir = Path(args.test_output_dir or args.output or ".") if args.test else None From 0a1536620f5847fe1c86fdd8312d2a96d5fe48cf Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:20:17 +0800 Subject: [PATCH 10/14] =?UTF-8?q?docs:=20=E6=B7=BB=E5=8A=A0=20FlowModel=20?= =?UTF-8?q?=E6=B5=81=E7=A8=8B=E6=A8=A1=E5=9E=8B=E6=96=87=E6=A1=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- design.md | 148 +++++++++++++++++ manual.md | 485 ++++++++++++++++++++++++++++++++++++------------------ 2 files changed, 474 insertions(+), 159 deletions(-) diff --git a/design.md b/design.md index cfc3b4e..1b869dd 100644 --- a/design.md +++ b/design.md @@ -76,3 +76,151 @@ When external calls are present, the generated code structure becomes: ### 6.4 CSE Interaction External operator outputs are plain SymPy symbols (placeholders), so CSE does not attempt to expand them. Input expressions are printed directly (not included in CSE), keeping the first implementation simple and reliable. + +## 7. Flow Layer — 流程模型 + +### 7.1 动机 + +`MathModel` 是纯函数式的(输入→输出映射),天然适合矩阵计算和算子分解。但有限元中大量逻辑本质上是命令式的: + +- **积分点循环**:遍历积分点,累加贡献到全局矩阵 +- **状态累积**:缓冲区清零→逐点计算→累加→输出 +- **条件分支**:根据材料类型/单元状态选择不同计算路径 + +`FlowModel` 提供命令式主流程抽象,将多个 `MathModel` 子模型编排为完整的计算流程。 + +### 7.2 数据模型 + +#### 7.2.1 FlowModel + +```python +FlowModel(name, inputs, outputs, body, local_buffers=None, submodels=None) +``` + +| 字段 | 类型 | 说明 | +|------|------|------| +| `name` | `str` | 函数名,生成 `compute_{name}` | +| `inputs` | `list[sp.Symbol]` | 输入符号列表 | +| `outputs` | `list[str \| sp.Symbol]` | 输出名称列表 | +| `body` | `list[Statement]` | 流程体(语句列表) | +| `local_buffers` | `list[Buffer]` | 局部缓冲区声明 | +| `submodels` | `dict[str, MathModel]` | 子模型映射 | + +#### 7.2.2 Buffer + +```python +Buffer(name, size, dtype="double") +``` + +局部缓冲区声明。在生成的代码中,缓冲区被声明为定长数组(C++/Fortran)或 `jnp.zeros`(JAX)。 + +#### 7.2.3 语句类层次 + +``` +Statement +├── Assign 标量赋值:var = expr +├── BufferZero 缓冲区清零:buf[:] = 0.0 +├── BufferCopy 缓冲区拷贝:dst[:] = src[:] +├── BufferAccum 缓冲区累加:dst[i] += src[i] +├── Call 调用子模型 +├── If 条件分支 +└── For 循环 +``` + +**语句语义:** + +| 语句 | 参数 | 语义 | +|------|------|------| +| `Assign(target, expr)` | target: str/Symbol, expr: Expr/int/float | 标量赋值 | +| `BufferZero(target)` | target: str(缓冲区名) | 将缓冲区全部置零 | +| `BufferCopy(target, source)` | target/source: str | 逐元素拷贝 | +| `BufferAccum(target, source)` | target/source: str | 逐元素累加 | +| `Call(model_name, input_exprs, output_vars)` | model_name: str, input_exprs: list[Expr/str], output_vars: list[str] | 调用 submodels 中的子模型 | +| `If(cond, then_body, else_body=None)` | cond: Expr, then/else: list[Statement] | 条件分支 | +| `For(index, start, end, body, unroll=False)` | index: Symbol, start/end: int/Expr, body: list[Statement] | 循环,unroll=True 时展开 | + +**Call 的输入输出约定:** +- `input_exprs` 中的 `str` 若在 `buffer_map` 中,视为缓冲区引用(展开为逐元素) +- `output_vars` 中的 `str` 若在 `buffer_map` 中,视为缓冲区引用;否则为标量变量 + +### 7.3 编译管线 + +``` +FlowModel + │ + ├─ FEACompiler.compile_flow(flow, target, chunk_size, cse_strategy) + │ + ├─ target = cpp/cuda ──→ _flow_to_source(flow, is_cuda) + ├─ target = fortran ──→ _flow_to_fortran(flow) + └─ target = jax ──→ _flow_to_jax(flow) +``` + +编译过程: + +1. **编译子模型**:遍历 `flow.submodels`,对每个 `MathModel` 调用 `FEACompiler.compile()` 生成子函数 +2. **构建查找表**:`buffer_map`(缓冲区名→Buffer 对象)、`scalar_vars`(Call 输出中的标量变量)、`for_indices`(循环变量) +3. **生成主流程函数**: + - 解包输入 + - 声明缓冲区 + 标量变量 + - 递归生成 body + - 输出映射 +4. **组装**:子模型函数源码 + 主流程函数 + +### 7.4 各平台代码生成策略 + +#### 7.4.1 C++ / CUDA + +命令式风格,直接映射: + +| 语句 | 生成代码 | +|------|----------| +| `Assign` | `target = expr;` | +| `BufferZero` | `for(int _i=0; _i`: Material name (e.g., `isotropic`). -- `--element `: Element name (e.g., `tet4`, `hex8`). -- `--script `: Path to a Python script for custom tasks. -- `--target {cpp,cuda,jax,peachpy,all}`: **(Required)** Target language. -- `--output `: (Optional) Output file or directory. -- `--test`: (Optional) Generate CI test assets alongside kernel code (C++/Fortran test wrappers, `test_driver.py`, build scripts). -- `--test-output-dir `: (Optional) Directory for test assets (defaults to `--output` if omitted). +```bash +python sympy_codegen.py --task <任务> --target <语言> [选项] +``` + +### 2.1 参数一览 + +| 参数 | 缩写 | 必需 | 说明 | +|------|------|------|------| +| `--task` | | 是 | 生成任务:`constitutive` / `stiffness` / `mass` / `custom` / `flow` | +| `--target` | `-t` | 是 | 目标语言:`cpp` / `cuda` / `fortran` / `jax` / `all` | +| `--material` | `-m` | 条件 | 材料名称(`--task constitutive` 时必需) | +| `--element` | `-e` | 条件 | 单元名称(`--task stiffness|mass` 时必需) | +| `--script` | `-s` | 条件 | Python 脚本路径(`--task custom|flow` 时必需) | +| `--output` | `-o` | 否 | 输出路径,默认自动生成 | +| `--chunk-size` | | 否 | CSE 分块大小,省略则由策略自动决定 | +| `--cse-strategy` | | 否 | `auto`(默认)或 `fixed` | +| `--test` | | 否 | 同时生成 CI 测试资产 | +| `--test-output-dir` | | 否 | 测试资产输出目录,默认与 `--output` 相同 | + +### 2.2 典型用法 + +```bash +# 材料 D 矩阵 → C++ +python sympy_codegen.py --task constitutive --material isotropic --target cpp -## 3. Advanced Features +# 单元刚度算子 → 全平台 +python sympy_codegen.py --task stiffness --element tet4 --target all -### 3.1 Operator-Based Decoupling -For complex elements, a single monolithic kernel can be slow to compile and hard to vectorize. The generator supports splitting the calculation into modular **operators**: -1. **dN_dnat**: Shape function derivatives in natural coordinates. -2. **Mapping**: Jacobian calculation and physical coordinate derivatives (`dN_dx`). -3. **Assembly**: Integration point contribution ($B^T D B \det(J) W$). -4. **Lumped Mass**: Element-level mass distribution for explicit dynamics. +# 自定义模型 → JAX +python sympy_codegen.py --task custom --script my_model.py --target jax -### 3.2 External Operators +# 流程模型 → CUDA +python sympy_codegen.py --task flow --script my_flow.py --target cuda +``` + +### 2.3 输出文件命名 + +| 目标 | 扩展名 | +|------|--------| +| `cpp` | `_gen.cpp` | +| `cuda` | `_gen.cu` | +| `fortran` | `_gen.f90` | +| `jax` | `_gen.py` | + +--- -External operators allow you to skip symbolic expansion of expensive computations (e.g., large matrix inversion) and instead emit function calls to externally implemented routines. +## 3. MathModel — 声明式数学模型 -#### 3.2.1 Define an External Operator +`MathModel` 描述纯函数式的输入→输出映射,适用于矩阵计算、算子分解等场景。 + +### 3.1 构造 ```python from sympy_codegen import MathModel, ExternalOperator, external_call +import sympy as sp +model = MathModel( + inputs=[x, y], # list[sp.Symbol] — 输入符号 + outputs=[x**2 + y**2], # list[sp.Expr] — 输出表达式 + name="my_kernel", # str — 函数名 + input_names=None, # list[str] | None — 输入名称(默认用符号名) + output_names=None, # list[str] | None — 输出名称 + is_operator=False, # bool — 是否作为 SIMD 算子生成 + external_ops=None, # dict[str, ExternalOperator] + external_calls=None, # list[ExternalCall](由 external_call() 自动填充) +) +``` + +### 3.2 外部算子 + +当某些计算(如大矩阵求逆)不适合符号展开时,使用外部算子跳过符号化,在生成代码中插入函数调用。 + +**定义外部算子:** + +```python ext_ops = { "inv12": ExternalOperator( name="inv12", - n_inputs=144, # 12×12 = 144 elements + n_inputs=144, # 12×12 = 144 n_outputs=144, cpp_func="fea_inv12", fortran_func="fea_inv12", - jax_func=None, # Not supported in JAX + jax_func=None, # None 表示 JAX 不支持 ) } ``` -#### 3.2.2 Use in a MathModel +**使用外部算子:** ```python -import sympy as sp - -A_syms = list(sp.symbols("A_0:144", real=True)) - -model = MathModel( - inputs=A_syms, - outputs=[], - name="kernel_with_inv", - external_ops=ext_ops, -) +model = MathModel(inputs=A_syms, outputs=[], name="kernel_inv", external_ops=ext_ops) -# Register the external call — returns 144 placeholder symbols +# 注册调用 → 返回占位符号列表 invA = external_call(model, "inv12", A_syms) -# Use the output symbols in subsequent expressions -B_syms = list(sp.symbols("B_0:144", real=True)) -model.inputs = A_syms + B_syms - -outputs = [] -for i in range(12): - for j in range(12): - val = sum(invA[i*12 + k] * B_syms[k*12 + j] for k in range(12)) - outputs.append(val) -model.outputs = outputs +# 占位符号可直接参与后续表达式 +model.outputs = [invA[i] * B_syms[i] for i in range(144)] ``` -#### 3.2.3 Multiple Calls with Different Prefixes +**多次调用同一算子时使用不同前缀:** ```python invA = external_call(model, "inv12", A_syms, prefix="invA") invB = external_call(model, "inv12", B_syms, prefix="invB") ``` -#### 3.2.4 Generated Code Example (C++) +> **JAX 限制**:若某外部算子的 `jax_func=None`,则该模型无法生成 JAX 代码。 -```cpp -// --- External Operator: inv12 --- -double inv12_in[144]; -double inv12_out[144]; -inv12_in[0] = A_0; -// ... -inv12_in[143] = A_143; -fea_inv12(inv12_in, inv12_out); -double inv12_0 = inv12_out[0]; -// ... -double inv12_143 = inv12_out[143]; +### 3.3 算子分解架构 -// --- Chunk 0 (normal CSE) --- -out[0] = ...; +刚度计算可拆分为解耦算子,每个算子是一个独立的 `MathModel`: + +| 算子 | 功能 | +|------|------| +| `dN_dnat` | 自然坐标下的形函数导数 | +| `mapping` | Jacobian 计算 + 物理坐标导数 | +| `assembly` | 积分点贡献 $B^T D B \det(J) W$ | +| `lumped_mass` | 集中质量分布 | + +在 Element 类中实现: + +```python +def get_stiffness_operators(self) → [op_dN, op_map, op_asm] +def get_mass_operators(self) → [op_mass] ``` -#### 3.2.5 JAX Limitation +--- -If `jax_func` is `None`, generating JAX code for a model that uses that operator will raise a `ValueError`. +## 4. FlowModel — 命令式流程模型 -### 3.3 Fast Validation Solvers (JAX) -Two scripts are provided for rapid verification of generated JAX kernels: -- `static.py`: Solves linear static problems using implicit integration. -- `explicit.py`: Solves dynamic problems using the Central Difference method (explicit). +`FlowModel` 描述命令式主流程,支持缓冲区操作、循环、条件分支和子模型调用。适用于需要迭代、状态累积等命令式逻辑的场景。 -**Example Usage (Explicit):** -```bash -python explicit.py --model test_case/tet4_mat1_ex/tet4_mat1_ex.jsonc --element tet4 --material isotropic +### 4.1 构造 + +```python +from sympy_codegen import ( + FlowModel, Buffer, Assign, BufferZero, BufferCopy, BufferAccum, + Call, If, For, +) +import sympy as sp + +flow = FlowModel( + name="my_flow", # str — 函数名 + inputs=[x, y, z], # list[sp.Symbol] — 输入符号 + outputs=["result", "buf_out"], # list[str | sp.Symbol] + body=[...], # list[Statement] — 流程体 + local_buffers=[Buffer("buf", 24)], # list[Buffer] — 局部缓冲区 + submodels={"sub": math_model}, # dict[str, MathModel] +) ``` -## 4. Performance Optimizations +### 4.2 Buffer — 缓冲区声明 -- **Chunked CSE**: Common Subexpression Elimination is performed in row-level chunks. This reduces generation time for large matrices (like 24x24) from hours to seconds. -- **Memory Alignment**: C++ outputs are structured to be "compiler-friendly" for auto-vectorization. -- **JAX Unpacking**: JAX kernels automatically unpack `in_flat` into named variables for readability and performance. +```python +Buffer(name, size, dtype="double") +``` -## 5. Integration Example (Explicit Operator Mode) +| 参数 | 类型 | 说明 | +|------|------|------| +| `name` | `str` | 缓冲区名称 | +| `size` | `int` | 元素数量(如 24×24 = 576) | +| `dtype` | `str` | 数据类型,默认 `"double"` | -Using decoupled operators in a JAX-based explicit solver: +### 4.3 语句类型 + +#### Assign — 标量赋值 ```python -# 1. dN/dnat (Constant for Tet4) -dN_dnat = kernels["tet4_op_dN_dnat"](jnp.array([0.25, 0.25, 0.25])) +Assign(target, expr) +``` -# 2. Mapping to physical space -map_output = kernels["tet4_op_mapping"](jnp.concatenate([coords_flat, dN_dnat])) -dN_dx, detJ = map_output[0:12], map_output[12] +```python +Assign("alpha", sp.Symbol("x") ** 2) # alpha = x**2 +``` -# 3. Assemble Stiffness -Ke = kernels["tet4_op_assembly"](jnp.concatenate([dN_dx, d_matrix, jnp.array([detJ, 1/6])])) +| 参数 | 类型 | 说明 | +|------|------|------| +| `target` | `str \| sp.Symbol` | 赋值目标 | +| `expr` | `sp.Expr \| int \| float` | 右端表达式 | -# 4. Compute Lumped Mass -Me_lumped = kernels["tet4_op_lumped_mass"](jnp.concatenate([coords_flat, jnp.array([rho])])) +#### BufferZero — 缓冲区清零 + +```python +BufferZero(target) ``` -## 6. CI Testing Workflow +```python +BufferZero("Ke") # Ke[:] = 0.0 +``` -The `--test` flag generates a complete set of test assets that allow cross-backend numerical validation: **SymPy (reference) vs C++ vs Fortran**. +#### BufferCopy — 缓冲区拷贝 -### 6.1 Generate Test Assets +```python +BufferCopy(target, source) +``` -Add `--test` to any code generation command. Use `--test-output-dir` to specify where test files go (defaults to `--output`). +```python +BufferCopy("tmp", "Ke") # tmp[:] = Ke[:] +``` -**Example — Constitutive model (isotropic D-matrix):** -```bash -python sympy_codegen.py --task constitutive --material isotropic --target all --output generated/isotropic_D --test --test-output-dir generated/isotropic_D +#### BufferAccum — 缓冲区累加 + +```python +BufferAccum(target, source) ``` -**Example — Stiffness operators (tet4 element):** -```bash -python sympy_codegen.py --task stiffness --element tet4 --target all --output generated/tet4 --test --test-output-dir generated/tet4 +```python +BufferAccum("Ke", "Ke_ip") # Ke[i] += Ke_ip[i] ``` -This produces, for each model/operator, a subdirectory containing: +> 两缓冲区尺寸相同时生成向量加法,否则生成逐元素累加。 + +#### Call — 调用子模型 + +```python +Call(model_name, input_exprs, output_vars) +``` -| File | Description | -|------|-------------| -| `kernel.cpp` | Generated C++ kernel | -| `kernel.f90` | Generated Fortran kernel | -| `main.cpp` | C++ test wrapper (reads stdin, writes stdout) | -| `main.f90` | Fortran test wrapper | -| `test_driver.py` | Python test driver (SymPy reference + subprocess comparison) | -| `build.sh` | Linux/macOS build script | -| `build.bat` | Windows build script | +```python +Call("constitutive", [E, nu], ["D_matrix"]) +Call("assembly", ["dN_dx", "D_matrix", detJ_w], ["Ke_ip"]) +``` -### 6.2 Build +| 参数 | 类型 | 说明 | +|------|------|------| +| `model_name` | `str` | `submodels` 中的键名 | +| `input_exprs` | `list[sp.Expr \| str]` | 输入表达式;`str` 且在缓冲区中则为缓冲区引用 | +| `output_vars` | `list[str]` | 输出变量名;在缓冲区中则为缓冲区引用,否则为标量 | -Enter the test directory and run the build script: +#### If — 条件分支 -**Linux / macOS:** -```bash -cd generated/tet4/tet4_op_dN_dnat -bash build.sh +```python +If(cond, then_body, else_body=None) ``` -**Windows:** -```cmd -cd generated\tet4\tet4_op_dN_dnat -build.bat +```python +If( + sp.Symbol("flag") > 0, + [Assign("mode", 1)], + [Assign("mode", 0)], +) ``` -The build script auto-detects available compilers: -- C++: tries `clang++`, `g++`, or MSVC `cl` (in that order) -- Fortran: uses `gfortran` +| 参数 | 类型 | 说明 | +|------|------|------| +| `cond` | `sp.Expr` | 条件表达式(Relational) | +| `then_body` | `list[Statement]` | 条件为真时的语句列表 | +| `else_body` | `list[Statement] \| None` | 条件为假时的语句列表 | -Compiler flags: `-O2 -fno-fast-math` (ensures IEEE 754 compliance for reproducible floating-point behavior). +#### For — 循环 -Output executables: `kernel_cpp.exe` (or `kernel_cpp` on Linux) and `kernel_f90.exe` (or `kernel_f90`). +```python +For(index, start, end, body, unroll=False) +``` -### 6.3 Run Tests +```python +i = sp.Symbol("i") +For(i, 0, 8, [ + Call("ip_kernel", ["coords", i], ["Ke_ip"]), + BufferAccum("Ke", "Ke_ip"), +]) +``` -The `test_driver.py` generates random inputs, computes the reference result via SymPy lambdify, then feeds the same input to the compiled C++/Fortran executables and compares outputs. +| 参数 | 类型 | 说明 | +|------|------|------| +| `index` | `sp.Symbol` | 循环变量 | +| `start` | `int \| sp.Expr` | 起始值(含) | +| `end` | `int \| sp.Expr` | 终止值(不含) | +| `body` | `list[Statement]` | 循环体 | +| `unroll` | `bool` | `True` 时展开循环,`False` 时生成循环语句 | -**Basic usage (both backends):** -```bash -python test_driver.py --cpp-exe kernel_cpp.exe --f90-exe kernel_f90.exe +### 4.4 完整示例 + +```python +import sympy as sp +from sympy_codegen import ( + FlowModel, MathModel, Buffer, Assign, BufferZero, BufferAccum, + Call, For, +) + +# 子模型:单积分点贡献 +ip_model = MathModel( + inputs=list(sp.symbols("dN_dx:12 detJ_w:2")), + outputs=list(sp.symbols("Ke_ip:144")), + name="ip_contribution", +) + +# 流程:8 积分点组装 +i = sp.Symbol("i") +flow = FlowModel( + name="assemble_Ke", + inputs=list(sp.symbols("coords:24 D:36")), + outputs=["Ke"], + local_buffers=[ + Buffer("Ke", 144), + Buffer("Ke_ip", 144), + ], + submodels={"ip_contribution": ip_model}, + body=[ + BufferZero("Ke"), + For(i, 0, 8, [ + Call("ip_contribution", ["coords", "D", i], ["Ke_ip"]), + BufferAccum("Ke", "Ke_ip"), + ]), + ], +) ``` -**C++ only:** -```bash -python test_driver.py --cpp-exe kernel_cpp.exe +将上述保存为 `my_flow.py` 并定义 `get_flow_model()` 函数: + +```python +def get_flow_model(): + return flow ``` -**Custom tolerance and run count:** +然后生成代码: + ```bash -python test_driver.py --cpp-exe kernel_cpp.exe --f90-exe kernel_f90.exe --n-runs 5000 --atol 1e-9 --rtol 1e-10 +python sympy_codegen.py --task flow --script my_flow.py --target all ``` -### 6.4 Test Driver CLI Arguments +### 4.5 各目标平台的语义映射 + +| FlowModel 语句 | C++ / CUDA | Fortran | JAX | +|----------------|------------|---------|-----| +| `Assign` | `var = expr;` | `var = expr` | `var = expr` | +| `BufferZero` | `for(_i) buf[_i]=0.0;` | `buf(:) = 0.0d0` | `buf = jnp.zeros(n)` | +| `BufferCopy` | `for(_i) dst[_i]=src[_i];` | `dst(:) = src(:)` | `dst = src.copy()` | +| `BufferAccum` | `for(_i) dst[_i]+=src[_i];` | `dst(:) = dst(:) + src(:)` | `dst = dst + src` | +| `Call` | `compute_xxx(in, out);` | `call compute_xxx(in, out)` | `out = compute_xxx(in)` | +| `If`(简单) | `if/else` | `if/else/end if` | `jnp.where(cond, a, b)` | +| `If`(复杂) | `if/else` | `if/else/end if` | `jax.lax.cond(...)` | +| `For`(unroll) | 展开循环体 | 展开循环体 | 展开循环体 | +| `For`(普通) | `for(int i=...; i<...; i++)` | `do i = ..., ... - 1` | `jax.lax.fori_loop(...)` | -| Argument | Default | Description | -|----------|---------|-------------| -| `--n-runs` | 1000 | Number of random test cases | -| `--atol` | 1e-10 | Absolute tolerance for `np.allclose` | -| `--rtol` | 1e-11 | Relative tolerance for `np.allclose` | -| `--seed` | 42 | Random seed for reproducibility | -| `--cpp-exe` | (none) | Path to compiled C++ executable | -| `--f90-exe` | (none) | Path to compiled Fortran executable | -| `--input-range` | 0.1 2.0 | Range for uniform random inputs | +> **JAX 简单 If**:当 `then_body` 和 `else_body` 中仅含 `Assign` / `BufferZero` / `BufferCopy` 时,使用 `jnp.where`;包含 `Call` / `For` / `If` / `BufferAccum` 时,使用 `jax.lax.cond`。 -At least one of `--cpp-exe` or `--f90-exe` must be provided. +--- -### 6.5 Interpreting Results +## 5. CI 测试 -On success: +### 5.1 生成测试资产 + +添加 `--test` 标志即可在生成内核的同时生成完整的交叉验证测试套件: + +```bash +python sympy_codegen.py --task stiffness --element tet4 --target all --test --output generated/tet4 ``` -Results: 1000/1000 passed, 0 failed (atol=1e-10, rtol=1e-11) + +生成的测试目录结构: + +| 文件 | 说明 | +|------|------| +| `kernel.cpp` | 生成的 C++ 内核 | +| `kernel.f90` | 生成的 Fortran 内核 | +| `main.cpp` | C++ 测试封装 | +| `main.f90` | Fortran 测试封装 | +| `test_driver.py` | Python 测试驱动(SymPy 参考值 + 子进程对比) | +| `build.sh` / `build.bat` | 构建脚本 | + +### 5.2 构建与运行 + +```bash +cd generated/tet4/tet4_op_dN_dnat +bash build.sh # Linux/macOS +build.bat # Windows + +python test_driver.py --cpp-exe kernel_cpp.exe --f90-exe kernel_f90.exe ``` -On failure, a detailed debug dump is printed showing each output value from all backends and the maximum difference, making it easy to identify which output diverged and by how much. +### 5.3 测试驱动参数 + +| 参数 | 默认值 | 说明 | +|------|--------|------| +| `--n-runs` | 1000 | 随机测试用例数 | +| `--atol` | 1e-10 | 绝对容差 | +| `--rtol` | 1e-11 | 相对容差 | +| `--seed` | 42 | 随机种子 | +| `--cpp-exe` | — | C++ 可执行文件路径 | +| `--f90-exe` | — | Fortran 可执行文件路径 | +| `--input-range` | 0.1 2.0 | 均匀随机输入范围 | + +--- + +## 6. JAX 快速验证求解器 + +| 脚本 | 用途 | +|------|------| +| `static.py` | 线性静力求解(隐式积分) | +| `explicit.py` | 动力学显式求解(中心差分法) | + +```bash +python explicit.py --model test_case/tet4_mat1_ex/tet4_mat1_ex.jsonc --element tet4 --material isotropic +``` -### 6.6 Tolerance Guidelines +--- -- **atol=1e-10, rtol=1e-11** (defaults): Suitable for most operators. The `rtol` accounts for large dynamic range computations (e.g., matrix inversion producing values ~500, where absolute error ~1e-9 is within expected floating-point reordering precision). -- For operators with very small output magnitudes, you may need to relax `atol`. For very large magnitudes, `rtol` is the primary control. +## 7. 性能优化 +- **行级分块 CSE**:按行分块执行公共子表达式消除,将大规模矩阵的生成时间从小时级降至秒级 +- **内存对齐**:C++ 输出采用连续布局,便于编译器自动向量化 +- **JAX 解包**:自动将 `in_flat` 解包为命名变量 From f0d744165a2f09c5ac9136c254ba18c661202bdc Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:34:04 +0800 Subject: [PATCH 11/14] =?UTF-8?q?refactor(codegen):=20=E7=BB=9F=E4=B8=80?= =?UTF-8?q?=E7=BC=93=E5=86=B2=E5=8C=BA=E8=BE=93=E5=87=BA=E5=81=8F=E7=A7=BB?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E4=B8=8E=E6=94=AF=E6=8C=81=E9=80=9A=E7=94=A8?= =?UTF-8?q?=E5=A1=AB=E5=85=85?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/__init__.py | 6 +- codegen/compiler.py | 138 +++++++++++++++++++++++++------------------- codegen/model.py | 33 ++++++++--- 3 files changed, 108 insertions(+), 69 deletions(-) diff --git a/codegen/__init__.py b/codegen/__init__.py index 97afed7..05551c4 100644 --- a/codegen/__init__.py +++ b/codegen/__init__.py @@ -2,8 +2,8 @@ from codegen.model import ( ExternalOperator, ExternalCall, MathModel, external_call, - Assign, BufferZero, BufferCopy, BufferAccum, Call, If, For, - Buffer, FlowModel, + Assign, BufferFill, BufferZero, BufferCopy, BufferAccum, Call, If, For, + BufferRef, Buffer, FlowModel, ) from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter from codegen.printer import FEACodePrinter, FEAFortranPrinter @@ -16,12 +16,14 @@ "MathModel", "external_call", "Assign", + "BufferFill", "BufferZero", "BufferCopy", "BufferAccum", "Call", "If", "For", + "BufferRef", "Buffer", "FlowModel", "LoweredChunk", diff --git a/codegen/compiler.py b/codegen/compiler.py index e0b0dda..f5e3df4 100644 --- a/codegen/compiler.py +++ b/codegen/compiler.py @@ -5,8 +5,8 @@ from sympy.printing.numpy import JaxPrinter from codegen.model import ( - MathModel, FlowModel, Assign, BufferZero, BufferCopy, BufferAccum, - Call, If, For, Buffer, + MathModel, FlowModel, Assign, BufferFill, BufferCopy, BufferAccum, + Call, If, For, BufferRef, ) from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter from codegen.printer import FEACodePrinter, FEAFortranPrinter @@ -662,8 +662,16 @@ def _flow_to_source(flow: FlowModel, is_cuda=False, chunk_size=None, cse_strateg comment_lines.append(f" * - in[{i}]: {sym}") comment_lines.append(" * ") comment_lines.append(" * @param out Output array (double*). Layout:") - for i, name in enumerate(flow.outputs): - comment_lines.append(f" * - out[{i}]: {name}") + offset = 0 + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + comment_lines.append(f" * - out[{offset}..{offset + buf.size - 1}]: {s} (buffer, size={buf.size})") + offset += buf.size + else: + comment_lines.append(f" * - out[{offset}]: {s}") + offset += 1 comment_lines.append(" */") comment_block = "\n".join(comment_lines) @@ -692,17 +700,20 @@ def _flow_to_source(flow: FlowModel, is_cuda=False, chunk_size=None, cse_strateg # 生成 body FEACompiler._emit_body(flow.body, buffer_map, printer, body_lines, indent=1) - # 输出映射 + # 输出映射(按累积偏移) body_lines.append("") body_lines.append(" // --- Output ---") - for i, out_name in enumerate(flow.outputs): + offset = 0 + for out_name in flow.outputs: s = str(out_name) if s in buffer_map: buf = buffer_map[s] for j in range(buf.size): - body_lines.append(f" out[{i * buf.size + j}] = {s}[{j}];") + body_lines.append(f" out[{offset + j}] = {s}[{j}];") + offset += buf.size else: - body_lines.append(f" out[{i}] = {s};") + body_lines.append(f" out[{offset}] = {s};") + offset += 1 body = "\n".join(body_lines) @@ -753,7 +764,7 @@ def _flow_to_fortran(flow: FlowModel, chunk_size=None, cse_strategy="auto"): f"subroutine compute_{flow.name}(in_vec, out_vec)", " implicit none", f" double precision, intent(in) :: in_vec({len(flow.inputs)})", - f" double precision, intent(out) :: out_vec({len(flow.outputs)})", + f" double precision, intent(out) :: out_vec({sum(buffer_map[str(o)].size if str(o) in buffer_map else 1 for o in flow.outputs)})", ] # 解包输入 @@ -856,9 +867,10 @@ def _emit_body(body, buffer_map, printer, lines, indent=1): expr_str = FEACompiler._print_expr(stmt.expr, printer) lines.append(f"{pad}{target} = {expr_str};") - elif isinstance(stmt, BufferZero): + elif isinstance(stmt, BufferFill): buf = buffer_map[stmt.target] - lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = 0.0;") + fill_val = stmt.value + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {fill_val};") elif isinstance(stmt, BufferCopy): buf = buffer_map[stmt.target] @@ -882,6 +894,14 @@ def _emit_body(body, buffer_map, printer, lines, indent=1): else: total_in += 1 + # 计算总输出元素数(buffer 展开为逐元素) + total_out = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + total_out += buffer_map[var_name].size + else: + total_out += 1 + lines.append(f"{pad}{{") lines.append(f"{pad} double _call_in[{total_in}];") # 填充输入 @@ -896,43 +916,21 @@ def _emit_body(body, buffer_map, printer, lines, indent=1): lines.append(f"{pad} _call_in[{in_idx}] = {FEACompiler._print_expr(e, printer)};") in_idx += 1 - # 输出:标量用临时变量,缓冲区直接传 - scalar_outs = [] - buf_outs = [] + # 输出:构造统一的 _call_out 数组,然后按累积偏移分发 + lines.append(f"{pad} double _call_out[{total_out}];") + lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") + + # 按累积偏移分发输出 + offset = 0 for var_name in stmt.output_vars: if var_name in buffer_map: - buf_outs.append(var_name) + buf = buffer_map[var_name] + for j in range(buf.size): + lines.append(f"{pad} {var_name}[{j}] = _call_out[{offset + j}];") + offset += buf.size else: - scalar_outs.append(var_name) - - if scalar_outs: - lines.append(f"{pad} double _call_out_scalar[{len(scalar_outs)}];") - - # 构建连续输出数组 - if buf_outs: - # 缓冲区直接传递,标量先写进临时数组再取回 - # 需要知道子模型的 output layout - # 简化方案:总是构造完整 _call_out 数组,然后拷贝回去 - sub_model = None # 不引用 flow.submodels 以保持纯静态 - # 用通用方案:构造 _call_out 数组 - total_out = len(stmt.output_vars) - lines.append(f"{pad} double _call_out[{total_out}];") - lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") - - # 把输出分发给各个变量 - for i, var_name in enumerate(stmt.output_vars): - if var_name in buffer_map: - buf = buffer_map[var_name] - for j in range(buf.size): - lines.append(f"{pad} {var_name}[{j}] = _call_out[{i * buf.size + j}];") - else: - lines.append(f"{pad} {var_name} = _call_out[{i}];") - else: - # 全标量输出 - lines.append(f"{pad} double _call_out[{len(scalar_outs)}];") - lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") - for i, var_name in enumerate(scalar_outs): - lines.append(f"{pad} {var_name} = _call_out[{i}];") + lines.append(f"{pad} {var_name} = _call_out[{offset}];") + offset += 1 lines.append(f"{pad}}}") @@ -975,9 +973,13 @@ def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): expr_str = FEACompiler._print_expr(stmt.expr, printer) lines.append(f"{pad}{target} = {expr_str}") - elif isinstance(stmt, BufferZero): + elif isinstance(stmt, BufferFill): buf = buffer_map[stmt.target] - lines.append(f"{pad}{stmt.target}(:) = 0.0d0") + fill_val = stmt.value + if fill_val == 0.0: + lines.append(f"{pad}{stmt.target}(:) = 0.0d0") + else: + lines.append(f"{pad}{stmt.target}(:) = {fill_val}d0") elif isinstance(stmt, BufferCopy): lines.append(f"{pad}{stmt.target}(:) = {stmt.source}(:)") @@ -1000,7 +1002,13 @@ def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): lines.append(f"{pad}! Call: {model_name}") lines.append(f"{pad}block") lines.append(f"{pad} double precision :: _call_in({total_in})") - total_out = len(stmt.output_vars) + # 计算总输出元素数(buffer 展开为逐元素) + total_out = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + total_out += buffer_map[var_name].size + else: + total_out += 1 lines.append(f"{pad} double precision :: _call_out({total_out})") in_idx = 0 @@ -1016,14 +1024,17 @@ def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): lines.append(f"{pad} call compute_{model_name}(_call_in, _call_out)") - # 分发输出 - for i, var_name in enumerate(stmt.output_vars): + # 按累积偏移分发输出 + offset = 0 + for var_name in stmt.output_vars: if var_name in buffer_map: buf = buffer_map[var_name] for j in range(buf.size): - lines.append(f"{pad} {var_name}({j + 1}) = _call_out({i * buf.size + j + 1})") + lines.append(f"{pad} {var_name}({j + 1}) = _call_out({offset + j + 1})") + offset += buf.size else: - lines.append(f"{pad} {var_name} = _call_out({i + 1})") + lines.append(f"{pad} {var_name} = _call_out({offset + 1})") + offset += 1 lines.append(f"{pad}end block") @@ -1105,7 +1116,7 @@ def _substitute_index(body, index_sym, value): new_body.append(For(stmt.index, new_start, new_end, new_body_inner, stmt.unroll)) else: - # BufferZero, BufferCopy, BufferAccum — 不含表达式,直接拷贝 + # BufferFill, BufferCopy, BufferAccum — 不含表达式,直接拷贝 new_body.append(stmt) return new_body @@ -1217,9 +1228,13 @@ def _emit_body_jax(body, buffer_map, printer, lines, indent=1, for_indices=None) expr_str = FEACompiler._print_expr(stmt.expr, printer) lines.append(f"{pad}{target} = {expr_str}") - elif isinstance(stmt, BufferZero): + elif isinstance(stmt, BufferFill): buf = buffer_map[stmt.target] - lines.append(f"{pad}{stmt.target} = jnp.zeros({buf.size})") + fill_val = stmt.value + if fill_val == 0.0: + lines.append(f"{pad}{stmt.target} = jnp.zeros({buf.size})") + else: + lines.append(f"{pad}{stmt.target} = jnp.full({buf.size}, {fill_val})") elif isinstance(stmt, BufferCopy): lines.append(f"{pad}{stmt.target} = {stmt.source}.copy()") @@ -1378,14 +1393,19 @@ def _emit_simple_if_jax(if_stmt, buffer_map, printer, lines, indent): else: lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") - elif isinstance(s, BufferZero): - lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.zeros({buffer_map[s.target].size}), {s.target})") + elif isinstance(s, BufferFill): + buf = buffer_map[s.target] + fill_val = s.value + if fill_val == 0.0: + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.zeros({buf.size}), {s.target})") + else: + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.full({buf.size}, {fill_val}), {s.target})") elif isinstance(s, BufferCopy): lines.append(f"{pad}{s.target} = jnp.where({cond_str}, {s.source}.copy(), {s.target})") # else_body 中独有的赋值 - then_targets = {str(s.target) for s in if_stmt.then_body if isinstance(s, (Assign, BufferZero, BufferCopy))} + then_targets = {str(s.target) for s in if_stmt.then_body if isinstance(s, (Assign, BufferFill, BufferCopy))} for s in if_stmt.else_body: if isinstance(s, Assign) and str(s.target) not in then_targets: target = str(s.target) @@ -1448,7 +1468,7 @@ def _collect_carried_vars(body, buffer_map): for stmt in body: if isinstance(stmt, Assign): carried.add(str(stmt.target)) - elif isinstance(stmt, BufferZero): + elif isinstance(stmt, BufferFill): carried.add(stmt.target) elif isinstance(stmt, BufferCopy): carried.add(stmt.target) diff --git a/codegen/model.py b/codegen/model.py index 3f6c451..cde2620 100644 --- a/codegen/model.py +++ b/codegen/model.py @@ -63,10 +63,14 @@ def __init__(self, target, expr): # ─── 缓冲区操作 ───────────────────────────────────────── -class BufferZero: - """缓冲区清零:buf[:] = 0.0""" - def __init__(self, target): +class BufferFill: + """缓冲区填充:buf[:] = value""" + def __init__(self, target, value=0.0): self.target = target # str (必须在 FlowModel.local_buffers 中声明) + self.value = value # float, 填充值 (默认 0.0) + +# 兼容别名 +BufferZero = BufferFill class BufferCopy: """缓冲区拷贝:target[:] = source[:]""" @@ -107,12 +111,25 @@ def __init__(self, index, start, end, body, unroll=False): # ─── 缓冲区声明 ───────────────────────────────────────── -class Buffer: +class BufferRef: """局部缓冲区声明""" - def __init__(self, name, size, dtype="double"): - self.name = name # str - self.size = size # int (标量元素数,如 24*24=576) - self.dtype = dtype # "double" | "int" | ... + def __init__(self, name, shape, dtype="double", layout="flat"): + self.name = name # str + self.shape = shape if isinstance(shape, tuple) else (shape,) + # tuple[int, ...] | int → 统一转为 tuple + self.dtype = dtype # "double" | "int" | ... + self.layout = layout # "flat" (默认, 生成代码仍用一维 buffer) + + @property + def size(self): + """缓冲区标量元素总数""" + result = 1 + for dim in self.shape: + result *= dim + return result + +# 兼容别名 +Buffer = BufferRef # ─── 流程模型 ─────────────────────────────────────────── From 36bdc52973ead5bab16a6143c5b7df52bb447a39 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Fri, 24 Apr 2026 23:50:21 +0800 Subject: [PATCH 12/14] =?UTF-8?q?refactor(codegen):=20=E6=8B=86=E5=88=86?= =?UTF-8?q?=20FEACompiler=20=E4=B8=BA=20MathCompiler=20=E5=92=8C=20FlowCom?= =?UTF-8?q?piler?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/compiler.py | 1489 +------------------------------------- codegen/flow_compiler.py | 841 +++++++++++++++++++++ codegen/math_compiler.py | 656 +++++++++++++++++ 3 files changed, 1503 insertions(+), 1483 deletions(-) create mode 100644 codegen/flow_compiler.py create mode 100644 codegen/math_compiler.py diff --git a/codegen/compiler.py b/codegen/compiler.py index f5e3df4..bde2f09 100644 --- a/codegen/compiler.py +++ b/codegen/compiler.py @@ -1,1486 +1,9 @@ -import re +"""compiler.py — 兼容门面,re-export MathCompiler + FlowCompiler,保持 FEACompiler API 不变。""" -import sympy as sp -from sympy.core.relational import Relational -from sympy.printing.numpy import JaxPrinter +from codegen.math_compiler import MathCompiler +from codegen.flow_compiler import FlowCompiler -from codegen.model import ( - MathModel, FlowModel, Assign, BufferFill, BufferCopy, BufferAccum, - Call, If, For, BufferRef, -) -from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter -from codegen.printer import FEACodePrinter, FEAFortranPrinter -from definitions.abc import Element - -class FEACompiler: - # ========================================================================= - # 公共 Lower 阶段:将 MathModel 转换为 LoweredModel,执行 CSE - # ========================================================================= - @staticmethod - def lower_model(model: MathModel, chunk_size: int) -> LoweredModel: - """执行 CSE lowering,返回可被多个后端共享的 LoweredModel""" - outputs = model.outputs - chunks = [] - - for start in range(0, len(outputs), chunk_size): - chunk_index = start // chunk_size - chunk = outputs[start:start + chunk_size] - sub_exprs, simplified_chunk = sp.cse( - chunk, - symbols=sp.numbered_symbols(f"v_{chunk_index}_") - ) - chunks.append( - LoweredChunk( - chunk_index=chunk_index, - start_index=start, - sub_exprs=sub_exprs, - simplified_outputs=simplified_chunk - ) - ) - - return LoweredModel(model.name, chunk_size, chunks, - external_calls=model.external_calls, - external_ops=model.external_ops) - - # ========================================================================= - # Chunk Size 策略:根据模型规模和目标平台决定 chunk size - # ========================================================================= - @staticmethod - def resolve_chunk_size(model: MathModel, target: str, user_chunk_size=None, strategy="auto") -> int: - """ - 决定 CSE chunk size 的策略。 - - Args: - model: 数学模型 - target: 目标平台 (jax/cpp/cuda/fortran等) - user_chunk_size: 用户通过 CLI 指定的 chunk size - strategy: 策略模式 ("auto" 或 "fixed") - - Returns: - 最终的 chunk size - """ - if user_chunk_size is not None: - return user_chunk_size - - nout = len(model.outputs) - target = target.lower() - - # fixed 模式:使用各后端的固定默认值 - if strategy == "fixed": - if target == "jax": - return 50 - if target in ("cpp", "c++", "cuda", "fortran"): - return 24 - return 24 - - # auto 模式:根据输出规模自动调整 - if strategy == "auto": - if target == "jax": - if nout <= 64: - return 64 - elif nout <= 256: - return 48 - else: - return 32 - - # cpp/cuda/fortran 的自适应策略 - if nout <= 32: - return 32 - elif nout <= 128: - return 24 - elif nout <= 512: - return 16 - else: - return 8 - - raise ValueError(f"Unknown strategy: {strategy}") - - # ========================================================================= - # C++/CUDA 兼容性宏:跨平台支持 GCC/Clang/MSVC/CUDA - # ========================================================================= - @staticmethod - def _cpp_cuda_compat_macros() -> str: - """返回统一的 C++/CUDA 跨平台兼容性宏定义""" - return r""" -#if defined(__CUDACC__) - #define FEA_DEVICE __device__ - #define FEA_HOST __host__ - #define FEA_HOST_DEVICE __host__ __device__ - #define FEA_RESTRICT __restrict__ -#else - #define FEA_DEVICE - #define FEA_HOST - #define FEA_HOST_DEVICE - #if defined(_WIN32) || defined(_WIN64) - #if defined(_MSC_VER) - #define FEA_RESTRICT __restrict - #else - #define FEA_RESTRICT __restrict__ - #endif - #else - #if defined(__GNUC__) || defined(__clang__) - #define FEA_RESTRICT __restrict__ - #else - #define FEA_RESTRICT - #endif - #endif -#endif - -#if defined(_MSC_VER) - #define FEA_ALWAYS_INLINE __forceinline -#elif defined(__GNUC__) || defined(__clang__) - #define FEA_ALWAYS_INLINE inline __attribute__((always_inline)) -#else - #define FEA_ALWAYS_INLINE inline -#endif -""" - - # ========================================================================= - # 核心编译接口 - # ========================================================================= - @staticmethod - def compile(model: MathModel, target: str, chunk_size=None, cse_strategy="auto", lowered=None): - """ - 核心分发器:输入 MathModel + target,输出 cpp/cuda/jax/fortran 源码字符串。 - - Args: - model: 数学模型 - target: 目标平台 ('jax', 'cpp', 'cuda', 'fortran') - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - lowered: 预先 lowered 的结果 (可选,用于多后端共享) - """ - target = target.lower() - if target == 'jax': - return FEACompiler._to_jax(model, lowered=lowered, chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target in ['cpp', 'c++']: - return FEACompiler._to_source(model, is_cuda=False, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'cuda': - return FEACompiler._to_source(model, is_cuda=True, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'fortran': - return FEACompiler._to_fortran(model, lowered=lowered, - chunk_size=chunk_size, cse_strategy=cse_strategy) - else: - raise ValueError(f"Unknown target: {target}") - - @staticmethod - def compile_all(model: MathModel, chunk_size=None, cse_strategy="auto", test=False, - task=None, model_name=None): - """ - 一次性生成 jax/cpp/cuda/fortran 四种目标源码。 - - 统一管理 lower 行为: - - 如果所有 target 使用相同的 chunk size,共享一份 lowered - - 如果 JAX 和 cpp/cuda/fortran 使用不同的 chunk size,分别生成 jax_lowered 和 shared_lowered - - Args: - model: 数学模型 - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - test: 是否同时生成测试资产(wrapper、test_driver、build script) - task: CLI 任务类型 ('constitutive', 'stiffness', 'mass', 'custom'),用于 test_driver 重新加载模型 - model_name: 模型/材料/单元名称,用于 test_driver 重新加载模型 - - Returns: - dict: {'jax': code, 'cpp': code, 'cuda': code, 'fortran': code, - 'cpp_wrapper': str, 'f90_wrapper': str, 'test_driver': str, - 'build_sh': str, 'build_bat': str} (后5项仅在 test=True 时存在) - """ - from ci_test.wrappers import generate_cpp_main, generate_f90_main - from ci_test.test_driver_template import generate_test_driver - from ci_test.build_script_generator import generate_build_sh, generate_build_bat - - # 决定各 target 的 chunk size - cpp_chunk = FEACompiler.resolve_chunk_size(model, "cpp", chunk_size, cse_strategy) - jax_chunk = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) - - # 生成 shared lowered 给 cpp/cuda/fortran - shared_lowered = FEACompiler.lower_model(model, cpp_chunk) - - # 决定 JAX 是否共享 lowered - if jax_chunk == cpp_chunk: - jax_lowered = shared_lowered - else: - jax_lowered = FEACompiler.lower_model(model, jax_chunk) - - result = { - "jax": FEACompiler._to_jax(model, lowered=jax_lowered, chunk_size=jax_chunk, cse_strategy=cse_strategy), - "cpp": FEACompiler._to_source(model, is_cuda=False, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - "cuda": FEACompiler._to_source(model, is_cuda=True, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - "fortran": FEACompiler._to_fortran(model, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), - } - - if test: - result["cpp_wrapper"] = generate_cpp_main(model) - result["f90_wrapper"] = generate_f90_main(model) - result["test_driver"] = generate_test_driver(model, task=task, model_name=model_name) - result["build_sh"] = generate_build_sh(model) - result["build_bat"] = generate_build_bat(model) - - return result - - @staticmethod - def _to_jax(model, lowered=None, chunk_size=None, cse_strategy="auto"): - """ - 生成 JAX 源码(.py),采用分块 CSE 优化。 - - Args: - model: 数学模型 - lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - """ - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - lines = [ - '"""Generated by sympy_codegen.py. Do not edit."""', - "import jax.numpy as jnp", - "", - "", - f"def compute_{model.name}(in_flat):", - f' """', - f' Compute the {model.name} kernel.', - f' ', - f' Args:', - f' in_flat: Flattened input array, size {len(model.inputs)}', - f' ', - f' Returns:', - f' Flattened output array, size {len(model.outputs)}', - f' ', - f' Input layout:', - ] - - # 添加输入信息 - for i, name in enumerate(model.input_names): - lines.append(f" ' - in_flat[{i}]: {name}") - - lines.append(f" '") - lines.append(f" ' Output layout:") - - # 添加输出信息 - for i, name in enumerate(model.output_names): - lines.append(f" ' - out[{i}]: {name}") - - lines.append(f' """') - - # Unpack inputs IF they are valid identifiers (e.g. xi, c0) - # If they are like "in[0]", we'll handle them via string replacement later - for i, sym in enumerate(model.inputs): - s = str(sym) - is_ident = s.isidentifier() - # print(f"DEBUG: sym={s}, is_ident={is_ident}") - if is_ident: - lines.append(f" {s} = in_flat[{i}]") - - lines.append("") - - printer = CachedPrinter(JaxPrinter()) - all_simplified_outputs = [] - - # 外部算子调用 - if lowered.external_calls: - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - if op.jax_func is None: - raise ValueError( - f"External operator '{call.op_name}' has no JAX implementation. " - f"Cannot generate JAX code for model '{model.name}'." - ) - lines.append(f" # --- External Operator: {call.op_name} ---") - in_parts = ", ".join(printer.doprint(e) for e in call.input_exprs) - lines.append(f" {call.prefix}_in = jnp.array([{in_parts}])") - lines.append(f" {call.prefix}_out = {op.jax_func}({call.prefix}_in)") - for i, sym in enumerate(call.output_symbols): - lines.append(f" {sym} = {call.prefix}_out[{i}]") - lines.append("") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - for var, expr in chunk.sub_exprs: - lines.append(f" {var} = {printer.doprint(expr)}") - - all_simplified_outputs.extend(chunk.simplified_outputs) - - lines.append("") - lines.append(" # --- Output ---") - out_parts = [printer.doprint(e) for e in all_simplified_outputs] - lines.append(f" return ({','.join(out_parts)})") - - src = "\n".join(lines) - # Final cleanup for JAX and handle C-style inputs - src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") - return src - - @staticmethod - def _to_source(model, is_cuda=False, lowered=None, chunk_size=None, cse_strategy="auto"): - """ - 生成 C++/CUDA 源码,采用分块 CSE 优化及算子化增强。 - - Args: - model: 数学模型 - is_cuda: 是否为 CUDA 目标 - lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) - chunk_size: 用户指定的 chunk size (可选) - cse_strategy: CSE 策略 ('auto' 或 'fixed') - """ - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "cuda" if is_cuda else "cpp", - chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - # --- Generate Comments --- - comment_lines = ["/**"] - comment_lines.append(f" * @brief Computes the {model.name} kernel.") - if model.is_operator: - comment_lines.append(" * @note This is an optimized operator kernel.") - comment_lines.append(" * ") - comment_lines.append(" * @param in Input array (const double*). Layout:") - - for i, name in enumerate(model.input_names): - comment_lines.append(f" * - in[{i}]: {name}") - - comment_lines.append(" * ") - comment_lines.append(" * @param out Output array (double*). Layout:") - - # 列出每个输出的详细信息 - for i, name in enumerate(model.output_names): - comment_lines.append(f" * - out[{i}]: {name}") - - comment_lines.append(" */") - comment_block = "\n".join(comment_lines) - - # --- Generate Function Body --- - body_lines = [] - - # 解包输入变量 - for i, sym in enumerate(model.inputs): - s = str(sym) - # 检查是否是合法标识符(如 coord_2_3),如果是则解包 - if s.isidentifier(): - body_lines.append(f" double {s} = in[{i}];") - - body_lines.append("") - - # 初始化带缓存的 Printer - printer = CachedPrinter(FEACodePrinter()) - - # 外部算子调用 - if lowered.external_calls: - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - body_lines.append(f" // --- External Operator: {call.op_name} ---") - body_lines.append(f" double {call.prefix}_in[{op.n_inputs}];") - body_lines.append(f" double {call.prefix}_out[{op.n_outputs}];") - - for i, expr in enumerate(call.input_exprs): - body_lines.append(f" {call.prefix}_in[{i}] = {printer.doprint(expr)};") - - body_lines.append(f" {op.cpp_func}({call.prefix}_in, {call.prefix}_out);") - - for i, sym in enumerate(call.output_symbols): - body_lines.append(f" double {sym} = {call.prefix}_out[{i}];") - - body_lines.append("") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - body_lines.append(f"\n // --- Chunk {chunk.chunk_index} ---") - - for var, expr in chunk.sub_exprs: - body_lines.append(f" double {var} = {printer.doprint(expr)};") - - for j, out_expr in enumerate(chunk.simplified_outputs): - body_lines.append(f" out[{chunk.start_index + j}] = {printer.doprint(out_expr)};") - - body = "\n".join(body_lines) - - # 统一使用兼容宏体系 - prefix = FEACompiler._cpp_cuda_compat_macros() + "\n" - - if is_cuda: - # CUDA 使用 FEA_DEVICE 宏 - func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" - signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" - else: - # C++ 使用 FEA_ALWAYS_INLINE 宏 - func_type = "FEA_ALWAYS_INLINE void" - signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" - - return f"{prefix}{comment_block}\n{signature} {{ \n{body}\n}}" - - - - @staticmethod - def _to_fortran(model, lowered=None, chunk_size=None, cse_strategy="auto"): - """生成 Fortran 源码,支持分块 CSE 优化。声明和赋值必须分离。""" - # 如果没有提供 lowered 结果,则自行 lower - if lowered is None: - chunk_size = FEACompiler.resolve_chunk_size(model, "fortran", chunk_size, cse_strategy) - lowered = FEACompiler.lower_model(model, chunk_size) - - printer = CachedPrinter(FEAFortranPrinter()) - - def _fortran_declare(type_decl, vars_list, indent=" "): - """Generate Fortran declaration with line continuation if exceeding 120 chars. - Fortran free-format limit is 132 chars; we use 120 for safety margin. - Continuation uses '&' at end of line and '&' at start of continuation. - The comma separator must appear at the end of the line (before &) - so that the continuation line can start cleanly with the next variable. - """ - if not vars_list: - return [] - max_len = 120 - prefix = f"{indent}{type_decl} :: " - # Try single line first - single_line = prefix + ", ".join(vars_list) - if len(single_line) <= max_len: - return [single_line] - # Split across multiple lines with continuation - # Strategy: each line ends with ", &" (comma before ampersand) - # and continuation lines start with "& " then the next variable - result_lines = [] - current = prefix - first = True - for v in vars_list: - # Check if adding this variable (with separator) would exceed limit - if first: - candidate = current + v - else: - candidate = current + ", " + v - if len(candidate) + 2 > max_len and not first: - # End current line with comma + ampersand for continuation - result_lines.append(current + ", &") - current = f"{indent}& {v}" - first = False - else: - current = candidate - first = False - result_lines.append(current) - return result_lines - - lines = [ - "! Generated by sympy_codegen.py. Do not edit.", - "!", - f"! Subroutine: compute_{model.name}", - "!", - "! Input array layout (in_vec):", - ] - - # 添加输入信息 - for i, name in enumerate(model.input_names): - lines.append(f"! in_vec({i + 1}): {name}") - - lines.append("!") - lines.append("! Output array layout (out_vec):") - - # 添加输出信息 - for i, name in enumerate(model.output_names): - lines.append(f"! out_vec({i + 1}): {name}") - - lines.extend([ - "!", - f"subroutine compute_{model.name}(in_vec, out_vec)", - " implicit none", - f" double precision, intent(in) :: in_vec({len(model.inputs)})", - f" double precision, intent(out) :: out_vec({len(model.outputs)})", - " ! --- Unpack inputs ---", - ]) - - # Unpack input array to named variables - input_vars = [] - for i, sym in enumerate(model.inputs): - s = str(sym) - if s.isidentifier(): - input_vars.append(s) - - # First, declare all input variables (with line continuation if needed) - if input_vars: - lines.extend(_fortran_declare("double precision", input_vars, " ")) - - # Then assign values - for i, sym in enumerate(model.inputs): - s = str(sym) - if s.isidentifier(): - lines.append(f" {s} = in_vec({i + 1})") - - # 外部算子调用 - if lowered.external_calls: - lines.append("") - lines.append(" ! --- External Operator Calls ---") - # 先声明所有外部算子相关的变量 - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - lines.extend(_fortran_declare("double precision", - [f"{call.prefix}_in({op.n_inputs})", f"{call.prefix}_out({op.n_outputs})"], " ")) - out_var_names = [str(sym) for sym in call.output_symbols] - lines.extend(_fortran_declare("double precision", out_var_names, " ")) - - # 然后赋值和调用 - for call in lowered.external_calls: - op = lowered.external_ops[call.op_name] - lines.append(f" ! External Operator: {call.op_name}") - for i, expr in enumerate(call.input_exprs): - lines.append(f" {call.prefix}_in({i + 1}) = {printer.doprint(expr)}") - lines.append(f" call {op.fortran_func}({call.prefix}_in, {call.prefix}_out)") - for i, sym in enumerate(call.output_symbols): - lines.append(f" {sym} = {call.prefix}_out({i + 1})") - - lines.append("") - - lines.append(" ! --- Local Variables for CSE ---") - - # 使用 lowered 结果 - for chunk in lowered.chunks: - lines.append(f" ! Chunk {chunk.chunk_index}") - if chunk.sub_exprs: - lines.append(" block") - - # Separate variables by type: logical for comparisons, double precision otherwise - dp_vars = [] - log_vars = [] - for var, expr in chunk.sub_exprs: - if isinstance(expr, Relational): - log_vars.append(str(var)) - else: - dp_vars.append(str(var)) - - if dp_vars: - lines.extend(_fortran_declare("double precision", dp_vars, " ")) - if log_vars: - lines.extend(_fortran_declare("logical", log_vars, " ")) - - # Then assign values - for var, expr in chunk.sub_exprs: - lines.append(f" {var} = {printer.doprint(expr)}") - - for j, out_expr in enumerate(chunk.simplified_outputs): - # Fortran arrays are 1-based. - lines.append(f" out_vec({chunk.start_index + j + 1}) = {printer.doprint(out_expr)}") - - if chunk.sub_exprs: - lines.append(" end block") - - lines.append(f"end subroutine compute_{model.name}") - src = "\n".join(lines) - # Replace C-style array access in[i] with Fortran 1-based in_vec(i+1) - # This handles SymPy symbols like in[0], in[1] that are not valid identifiers - def _replace_in_array(m): - idx = int(m.group(1)) - return f"in_vec({idx + 1})" - src = re.sub(r'in\[(\d+)\]', _replace_in_array, src) - return src - - @staticmethod - def compile_element(element: Element, target: str, chunk_size=None, cse_strategy="auto"): - """ - Special compiler for Elements: supports both single-kernel and operator-based generation. - """ - operators = element.get_stiffness_operators() - if operators: - # Generate multiple operator kernels - generated = {} - for op_model in operators: - generated[op_model.name] = FEACompiler.compile(op_model, target, - chunk_size=chunk_size, cse_strategy=cse_strategy) - return generated - else: - # Traditional single kernel - model = element.get_stiffness_model() - return {model.name: FEACompiler.compile(model, target, - chunk_size=chunk_size, cse_strategy=cse_strategy)} - - # ========================================================================= - # FlowModel 编译 — 命令式主流程代码生成 - # ========================================================================= - @staticmethod - def compile_flow(flow: FlowModel, target: str, chunk_size=None, cse_strategy="auto"): - """ - 编译 FlowModel,生成主流程函数 + 所有子模型函数的完整源码。 - - Args: - flow: FlowModel 实例 - target: 目标平台 ('cpp', 'cuda', 'fortran', 'jax') - chunk_size: 子模型 CSE chunk size (可选) - cse_strategy: CSE 策略 - - Returns: - str: 完整源码(包含子模型函数 + 主流程函数) - """ - target = target.lower() - if target in ('cpp', 'c++', 'cuda'): - is_cuda = (target == 'cuda') - return FEACompiler._flow_to_source(flow, is_cuda=is_cuda, - chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'fortran': - return FEACompiler._flow_to_fortran(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) - elif target == 'jax': - return FEACompiler._flow_to_jax(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) - else: - raise ValueError(f"FlowModel does not support target '{target}' yet") - - # ------------------------------------------------------------------------- - # C++/CUDA Flow 代码生成 - # ------------------------------------------------------------------------- - @staticmethod - def _flow_to_source(flow: FlowModel, is_cuda=False, chunk_size=None, cse_strategy="auto"): - """生成 FlowModel 的 C++/CUDA 源码""" - - # 1. 编译所有子模型 → compute_xxx 函数源码(剥离重复的宏定义) - macros_str = FEACompiler._cpp_cuda_compat_macros() - sub_sources = [] - for name, sub_model in flow.submodels.items(): - sub_src = FEACompiler.compile(sub_model, "cuda" if is_cuda else "cpp", - chunk_size=chunk_size, cse_strategy=cse_strategy) - # 剥离子模型源码中的宏定义(避免重复) - sub_src = sub_src.replace(macros_str, "").strip() - sub_sources.append(sub_src) - - # 2. 构建缓冲区查找表 - buffer_map = {b.name: b for b in flow.local_buffers} - - # 3. 收集标量变量(Call 输出中不在 buffer_map 的) - # 需要先遍历 body 收集 - scalar_vars = set() - FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) - - # 4. 生成主流程函数 - printer = CachedPrinter(FEACodePrinter()) - - # --- 函数注释 --- - comment_lines = ["/**"] - comment_lines.append(f" * @brief Flow kernel: {flow.name}") - comment_lines.append(" * ") - comment_lines.append(" * @param in Input array (const double*). Layout:") - for i, sym in enumerate(flow.inputs): - comment_lines.append(f" * - in[{i}]: {sym}") - comment_lines.append(" * ") - comment_lines.append(" * @param out Output array (double*). Layout:") - offset = 0 - for out_name in flow.outputs: - s = str(out_name) - if s in buffer_map: - buf = buffer_map[s] - comment_lines.append(f" * - out[{offset}..{offset + buf.size - 1}]: {s} (buffer, size={buf.size})") - offset += buf.size - else: - comment_lines.append(f" * - out[{offset}]: {s}") - offset += 1 - comment_lines.append(" */") - comment_block = "\n".join(comment_lines) - - # --- 函数体 --- - body_lines = [] - - # 解包输入 - for i, sym in enumerate(flow.inputs): - s = str(sym) - if s.isidentifier(): - body_lines.append(f" double {s} = in[{i}];") - body_lines.append("") - - # 声明缓冲区 - for buf in flow.local_buffers: - dtype = "double" if buf.dtype == "double" else buf.dtype - body_lines.append(f" {dtype} {buf.name}[{buf.size}];") - body_lines.append("") - - # 声明标量变量 - if scalar_vars: - for var in sorted(scalar_vars): - body_lines.append(f" double {var};") - body_lines.append("") - - # 生成 body - FEACompiler._emit_body(flow.body, buffer_map, printer, body_lines, indent=1) - - # 输出映射(按累积偏移) - body_lines.append("") - body_lines.append(" // --- Output ---") - offset = 0 - for out_name in flow.outputs: - s = str(out_name) - if s in buffer_map: - buf = buffer_map[s] - for j in range(buf.size): - body_lines.append(f" out[{offset + j}] = {s}[{j}];") - offset += buf.size - else: - body_lines.append(f" out[{offset}] = {s};") - offset += 1 - - body = "\n".join(body_lines) - - # --- 函数签名 --- - prefix_macros = FEACompiler._cpp_cuda_compat_macros() + "\n" - if is_cuda: - func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" - else: - func_type = "FEA_ALWAYS_INLINE void" - signature = f"{func_type} compute_{flow.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" - - main_func = f"{comment_block}\n{signature} {{\n{body}\n}}" - - # 组装:宏定义(1次) + 子模型函数 + 主流程函数 - parts = [prefix_macros] + sub_sources + [main_func] - return "\n\n".join(parts) - - # ------------------------------------------------------------------------- - # Fortran Flow 代码生成 - # ------------------------------------------------------------------------- - @staticmethod - def _flow_to_fortran(flow: FlowModel, chunk_size=None, cse_strategy="auto"): - """生成 FlowModel 的 Fortran 源码""" - - # 1. 编译所有子模型 - sub_sources = [] - for name, sub_model in flow.submodels.items(): - sub_src = FEACompiler.compile(sub_model, "fortran", - chunk_size=chunk_size, cse_strategy=cse_strategy) - sub_sources.append(sub_src) - - # 2. 构建查找表 - buffer_map = {b.name: b for b in flow.local_buffers} - scalar_vars = set() - FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) - # 收集 For 循环的 index 变量(Fortran 需声明为 integer) - for_indices = set() - FEACompiler._collect_for_indices(flow.body, for_indices) - - # 3. 生成主流程 - printer = CachedPrinter(FEAFortranPrinter()) - - lines = [ - "! Generated by sympy_codegen.py. Do not edit.", - "!", - f"! Flow kernel: compute_{flow.name}", - "!", - f"subroutine compute_{flow.name}(in_vec, out_vec)", - " implicit none", - f" double precision, intent(in) :: in_vec({len(flow.inputs)})", - f" double precision, intent(out) :: out_vec({sum(buffer_map[str(o)].size if str(o) in buffer_map else 1 for o in flow.outputs)})", - ] - - # 解包输入 - input_vars = [] - for i, sym in enumerate(flow.inputs): - s = str(sym) - if s.isidentifier(): - input_vars.append(s) - - if input_vars: - lines.append(" ! --- Unpack inputs ---") - lines.append(f" double precision :: {', '.join(input_vars)}") - for i, sym in enumerate(flow.inputs): - s = str(sym) - if s.isidentifier(): - lines.append(f" {s} = in_vec({i + 1})") - lines.append("") - - # 声明缓冲区 - if flow.local_buffers: - lines.append(" ! --- Local Buffers ---") - for buf in flow.local_buffers: - lines.append(f" double precision :: {buf.name}({buf.size})") - lines.append("") - - # 声明标量变量 - if scalar_vars: - lines.append(" ! --- Local Scalars ---") - lines.append(f" double precision :: {', '.join(sorted(scalar_vars))}") - lines.append("") - - # 声明 For 循环 index 变量 - if for_indices: - lines.append(" ! --- Loop Indices ---") - lines.append(f" integer :: {', '.join(sorted(for_indices))}") - lines.append("") - - # 生成 body - FEACompiler._emit_body_fortran(flow.body, buffer_map, printer, lines, indent=1) - - # 输出映射 - lines.append("") - lines.append(" ! --- Output ---") - offset = 0 - for out_name in flow.outputs: - s = str(out_name) - if s in buffer_map: - buf = buffer_map[s] - for j in range(buf.size): - lines.append(f" out_vec({offset + j + 1}) = {s}({j + 1})") - offset += buf.size - else: - lines.append(f" out_vec({offset + 1}) = {s}") - offset += 1 - - lines.append(f"end subroutine compute_{flow.name}") - - # 组装 - parts = sub_sources + ["\n".join(lines)] - return "\n\n".join(parts) - - # ========================================================================= - # 辅助方法:标量变量收集 - # ========================================================================= - @staticmethod - def _collect_scalar_vars(body, buffer_map, scalar_vars): - """递归收集 body 中 Call 输出的标量变量名""" - for stmt in body: - if isinstance(stmt, Call): - for var_name in stmt.output_vars: - if var_name not in buffer_map: - scalar_vars.add(var_name) - elif isinstance(stmt, If): - FEACompiler._collect_scalar_vars(stmt.then_body, buffer_map, scalar_vars) - FEACompiler._collect_scalar_vars(stmt.else_body, buffer_map, scalar_vars) - elif isinstance(stmt, For): - FEACompiler._collect_scalar_vars(stmt.body, buffer_map, scalar_vars) - - @staticmethod - def _collect_for_indices(body, indices): - """递归收集 For 语句的 index 变量名""" - for stmt in body: - if isinstance(stmt, For): - indices.add(str(stmt.index)) - FEACompiler._collect_for_indices(stmt.body, indices) - elif isinstance(stmt, If): - FEACompiler._collect_for_indices(stmt.then_body, indices) - FEACompiler._collect_for_indices(stmt.else_body, indices) - - # ========================================================================= - # 辅助方法:C++/CUDA 语句生成 - # ========================================================================= - @staticmethod - def _emit_body(body, buffer_map, printer, lines, indent=1): - """递归生成 C++/CUDA 语句""" - pad = " " * indent - for stmt in body: - if isinstance(stmt, Assign): - target = str(stmt.target) - expr_str = FEACompiler._print_expr(stmt.expr, printer) - lines.append(f"{pad}{target} = {expr_str};") - - elif isinstance(stmt, BufferFill): - buf = buffer_map[stmt.target] - fill_val = stmt.value - lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {fill_val};") - - elif isinstance(stmt, BufferCopy): - buf = buffer_map[stmt.target] - lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {stmt.source}[_i];") - - elif isinstance(stmt, BufferAccum): - buf = buffer_map[stmt.target] - lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] += {stmt.source}[_i];") - - elif isinstance(stmt, Call): - model_name = stmt.model_name - # 计算子模型的总输入/输出数量 - # input_exprs 中可能有标量引用或 SymPy 表达式 - # 当 input_exprs 元素是 str 且在 buffer_map 中时,需要逐元素填充 - lines.append(f"{pad}// Call: {model_name}") - # 计算总输入元素数 - total_in = 0 - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - total_in += buffer_map[e].size - else: - total_in += 1 - - # 计算总输出元素数(buffer 展开为逐元素) - total_out = 0 - for var_name in stmt.output_vars: - if var_name in buffer_map: - total_out += buffer_map[var_name].size - else: - total_out += 1 - - lines.append(f"{pad}{{") - lines.append(f"{pad} double _call_in[{total_in}];") - # 填充输入 - in_idx = 0 - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - buf = buffer_map[e] - for j in range(buf.size): - lines.append(f"{pad} _call_in[{in_idx}] = {e}[{j}];") - in_idx += 1 - else: - lines.append(f"{pad} _call_in[{in_idx}] = {FEACompiler._print_expr(e, printer)};") - in_idx += 1 - - # 输出:构造统一的 _call_out 数组,然后按累积偏移分发 - lines.append(f"{pad} double _call_out[{total_out}];") - lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") - - # 按累积偏移分发输出 - offset = 0 - for var_name in stmt.output_vars: - if var_name in buffer_map: - buf = buffer_map[var_name] - for j in range(buf.size): - lines.append(f"{pad} {var_name}[{j}] = _call_out[{offset + j}];") - offset += buf.size - else: - lines.append(f"{pad} {var_name} = _call_out[{offset}];") - offset += 1 - - lines.append(f"{pad}}}") - - elif isinstance(stmt, If): - cond_str = FEACompiler._print_expr(stmt.cond, printer) - lines.append(f"{pad}if ({cond_str}) {{") - FEACompiler._emit_body(stmt.then_body, buffer_map, printer, lines, indent + 1) - if stmt.else_body: - lines.append(f"{pad}}} else {{") - FEACompiler._emit_body(stmt.else_body, buffer_map, printer, lines, indent + 1) - lines.append(f"{pad}}}") - - elif isinstance(stmt, For): - idx = str(stmt.index) - start = stmt.start - end = stmt.end - - if stmt.unroll: - # 展开循环 - for i in range(int(start), int(end)): - lines.append(f"{pad}// Unrolled iteration {idx} = {i}") - # 替换 body 中引用 index 的表达式 - sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) - FEACompiler._emit_body(sub_body, buffer_map, printer, lines, indent) - else: - lines.append(f"{pad}for (int {idx} = {start}; {idx} < {end}; {idx}++) {{") - FEACompiler._emit_body(stmt.body, buffer_map, printer, lines, indent + 1) - lines.append(f"{pad}}}") - - # ========================================================================= - # 辅助方法:Fortran 语句生成 - # ========================================================================= - @staticmethod - def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): - """递归生成 Fortran 语句""" - pad = " " * indent - for stmt in body: - if isinstance(stmt, Assign): - target = str(stmt.target) - expr_str = FEACompiler._print_expr(stmt.expr, printer) - lines.append(f"{pad}{target} = {expr_str}") - - elif isinstance(stmt, BufferFill): - buf = buffer_map[stmt.target] - fill_val = stmt.value - if fill_val == 0.0: - lines.append(f"{pad}{stmt.target}(:) = 0.0d0") - else: - lines.append(f"{pad}{stmt.target}(:) = {fill_val}d0") - - elif isinstance(stmt, BufferCopy): - lines.append(f"{pad}{stmt.target}(:) = {stmt.source}(:)") - - elif isinstance(stmt, BufferAccum): - buf = buffer_map[stmt.target] - lines.append(f"{pad}{stmt.target}(:) = {stmt.target}(:) + {stmt.source}(:)") - - elif isinstance(stmt, Call): - model_name = stmt.model_name - - # 计算总输入元素数(buffer 展开为逐元素) - total_in = 0 - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - total_in += buffer_map[e].size - else: - total_in += 1 - - lines.append(f"{pad}! Call: {model_name}") - lines.append(f"{pad}block") - lines.append(f"{pad} double precision :: _call_in({total_in})") - # 计算总输出元素数(buffer 展开为逐元素) - total_out = 0 - for var_name in stmt.output_vars: - if var_name in buffer_map: - total_out += buffer_map[var_name].size - else: - total_out += 1 - lines.append(f"{pad} double precision :: _call_out({total_out})") - - in_idx = 0 - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - buf = buffer_map[e] - for j in range(buf.size): - lines.append(f"{pad} _call_in({in_idx + 1}) = {e}({j + 1})") - in_idx += 1 - else: - lines.append(f"{pad} _call_in({in_idx + 1}) = {FEACompiler._print_expr(e, printer)}") - in_idx += 1 - - lines.append(f"{pad} call compute_{model_name}(_call_in, _call_out)") - - # 按累积偏移分发输出 - offset = 0 - for var_name in stmt.output_vars: - if var_name in buffer_map: - buf = buffer_map[var_name] - for j in range(buf.size): - lines.append(f"{pad} {var_name}({j + 1}) = _call_out({offset + j + 1})") - offset += buf.size - else: - lines.append(f"{pad} {var_name} = _call_out({offset + 1})") - offset += 1 - - lines.append(f"{pad}end block") - - elif isinstance(stmt, If): - cond_str = FEACompiler._print_expr(stmt.cond, printer) - lines.append(f"{pad}if ({cond_str}) then") - FEACompiler._emit_body_fortran(stmt.then_body, buffer_map, printer, lines, indent + 1) - if stmt.else_body: - lines.append(f"{pad}else") - FEACompiler._emit_body_fortran(stmt.else_body, buffer_map, printer, lines, indent + 1) - lines.append(f"{pad}end if") - - elif isinstance(stmt, For): - idx = str(stmt.index) - start = stmt.start - end = stmt.end - - if stmt.unroll: - for i in range(int(start), int(end)): - lines.append(f"{pad}! Unrolled iteration {idx} = {i}") - sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) - FEACompiler._emit_body_fortran(sub_body, buffer_map, printer, lines, indent) - else: - lines.append(f"{pad}do {idx} = {start}, {end} - 1") - FEACompiler._emit_body_fortran(stmt.body, buffer_map, printer, lines, indent + 1) - lines.append(f"{pad}end do") - - # ========================================================================= - # 辅助方法:表达式打印 & 索引替换 - # ========================================================================= - @staticmethod - def _print_expr(expr, printer): - """将 SymPy 表达式或原始值打印为字符串""" - if isinstance(expr, (int, float)): - return str(expr) - if isinstance(expr, str): - return expr - if isinstance(expr, sp.Basic): - return printer.doprint(expr) - return str(expr) - - @staticmethod - def _substitute_index(body, index_sym, value): - """将 body 中所有引用 index_sym 的表达式替换为具体值,返回新的 body 列表""" - import copy - new_body = [] - for stmt in body: - if isinstance(stmt, Assign): - new_expr = stmt.expr - new_target = stmt.target - if isinstance(new_expr, sp.Basic): - new_expr = new_expr.subs(index_sym, value) - if isinstance(new_target, sp.Basic): - new_target = new_target.subs(index_sym, value) - new_body.append(Assign(new_target, new_expr)) - - elif isinstance(stmt, Call): - new_input_exprs = [] - for e in stmt.input_exprs: - if isinstance(e, sp.Basic): - new_input_exprs.append(e.subs(index_sym, value)) - else: - new_input_exprs.append(e) - new_body.append(Call(stmt.model_name, new_input_exprs, stmt.output_vars)) - - elif isinstance(stmt, If): - new_cond = stmt.cond - if isinstance(new_cond, sp.Basic): - new_cond = new_cond.subs(index_sym, value) - new_then = FEACompiler._substitute_index(stmt.then_body, index_sym, value) - new_else = FEACompiler._substitute_index(stmt.else_body, index_sym, value) - new_body.append(If(new_cond, new_then, new_else)) - - elif isinstance(stmt, For): - # 不替换嵌套 For 的 index(不同循环变量),但替换 body 内引用外层 index 的部分 - new_start = stmt.start.subs(index_sym, value) if isinstance(stmt.start, sp.Basic) else stmt.start - new_end = stmt.end.subs(index_sym, value) if isinstance(stmt.end, sp.Basic) else stmt.end - new_body_inner = FEACompiler._substitute_index(stmt.body, index_sym, value) - new_body.append(For(stmt.index, new_start, new_end, new_body_inner, stmt.unroll)) - - else: - # BufferFill, BufferCopy, BufferAccum — 不含表达式,直接拷贝 - new_body.append(stmt) - - return new_body - - # ========================================================================= - # JAX Flow 代码生成 - # ========================================================================= - @staticmethod - def _flow_to_jax(flow: FlowModel, chunk_size=None, cse_strategy="auto"): - """ - 生成 FlowModel 的 JAX 源码。 - - JAX 是纯函数式的,生成策略: - - 缓冲区 → jnp.zeros / jnp.array - - BufferAccum → buf = buf.at[i].add(src[i]) 或 buf = buf + src - - For(unroll=False) → jax.lax.fori_loop - - For(unroll=True) → Python 展开循环 - - If → jax.lax.cond (then/else 必须返回相同结构的 tuple) - - Call → 调用子模型函数 - """ - - # 1. 编译所有子模型 - sub_sources = [] - for name, sub_model in flow.submodels.items(): - sub_src = FEACompiler.compile(sub_model, "jax", - chunk_size=chunk_size, cse_strategy=cse_strategy) - sub_sources.append(sub_src) - - # 2. 构建查找表 - buffer_map = {b.name: b for b in flow.local_buffers} - scalar_vars = set() - FEACompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) - for_indices = set() - FEACompiler._collect_for_indices(flow.body, for_indices) - - # 3. 生成主流程 - from sympy.printing.numpy import JaxPrinter - printer = CachedPrinter(JaxPrinter()) - - lines = [ - '"""Generated by sympy_codegen.py. Do not edit."""', - "import jax", - "import jax.numpy as jnp", - "", - "", - f"def compute_{flow.name}(in_flat):", - f' """', - f' Flow kernel: {flow.name}', - f' ', - f' Args:', - f' in_flat: Flattened input array, size {len(flow.inputs)}', - f' ', - f' Returns:', - f' Flattened output array', - f' """', - ] - - # 解包输入 - for i, sym in enumerate(flow.inputs): - s = str(sym) - if s.isidentifier(): - lines.append(f" {s} = in_flat[{i}]") - lines.append("") - - # 初始化缓冲区 - for buf in flow.local_buffers: - if buf.dtype == "double": - lines.append(f" {buf.name} = jnp.zeros({buf.size})") - else: - lines.append(f" {buf.name} = jnp.zeros({buf.size}, dtype=jnp.{buf.dtype})") - lines.append("") - - # 声明标量变量初始值(用于 JAX 的函数式风格) - # JAX 中标量不需要预声明,在赋值时绑定即可 - - # 生成 body - FEACompiler._emit_body_jax(flow.body, buffer_map, printer, lines, indent=1, for_indices=for_indices) - - # 输出映射 - lines.append("") - lines.append(" # --- Output ---") - out_parts = [] - for out_name in flow.outputs: - s = str(out_name) - if s in buffer_map: - out_parts.append(s) - else: - out_parts.append(s) - - if len(out_parts) == 1: - lines.append(f" return {out_parts[0]}") - else: - lines.append(f" return jnp.concatenate([{', '.join(out_parts)}])") - - src = "\n".join(sub_sources) + "\n\n\n" + "\n".join(lines) - src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") - return src - - # ------------------------------------------------------------------------- - # JAX 语句生成 - # ------------------------------------------------------------------------- - @staticmethod - def _emit_body_jax(body, buffer_map, printer, lines, indent=1, for_indices=None): - """递归生成 JAX 语句""" - pad = " " * indent - for stmt in body: - if isinstance(stmt, Assign): - target = str(stmt.target) - expr_str = FEACompiler._print_expr(stmt.expr, printer) - lines.append(f"{pad}{target} = {expr_str}") - - elif isinstance(stmt, BufferFill): - buf = buffer_map[stmt.target] - fill_val = stmt.value - if fill_val == 0.0: - lines.append(f"{pad}{stmt.target} = jnp.zeros({buf.size})") - else: - lines.append(f"{pad}{stmt.target} = jnp.full({buf.size}, {fill_val})") - - elif isinstance(stmt, BufferCopy): - lines.append(f"{pad}{stmt.target} = {stmt.source}.copy()") - - elif isinstance(stmt, BufferAccum): - buf_target = buffer_map[stmt.target] - buf_source = buffer_map[stmt.source] - if buf_target.size == buf_source.size: - # 同尺寸:向量加法 - lines.append(f"{pad}{stmt.target} = {stmt.target} + {stmt.source}") - else: - # 不同尺寸:逐元素 at[].add()(不太常见,但保留安全路径) - lines.append(f"{pad}{stmt.target} = {stmt.target}.at[:{buf_source.size}].add({stmt.source})") - - elif isinstance(stmt, Call): - model_name = stmt.model_name - - # 计算总输入元素数 - total_in = 0 - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - total_in += buffer_map[e].size - else: - total_in += 1 - - # 构建输入 - in_parts = [] - for e in stmt.input_exprs: - if isinstance(e, str) and e in buffer_map: - in_parts.append(e) - else: - in_parts.append(FEACompiler._print_expr(e, printer)) - - lines.append(f"{pad}# Call: {model_name}") - if len(in_parts) == 1: - # 单个输入,可能是标量或数组 - lines.append(f"{pad}_call_in = jnp.array([{in_parts[0]}]) if not isinstance({in_parts[0]}, jnp.ndarray) else {in_parts[0]}.reshape(-1)") - else: - # 多个输入,拼接为 1D 数组 - items = ", ".join(in_parts) - lines.append(f"{pad}_call_in = jnp.concatenate([jnp.atleast_1d(jnp.asarray(x)) for x in [{items}]])") - - lines.append(f"{pad}_call_out = compute_{model_name}(_call_in)") - - # 分发输出 - offset = 0 - for var_name in stmt.output_vars: - if var_name in buffer_map: - buf = buffer_map[var_name] - lines.append(f"{pad}{var_name} = _call_out[{offset}:{offset + buf.size}]") - offset += buf.size - else: - lines.append(f"{pad}{var_name} = _call_out[{offset}]") - offset += 1 - - elif isinstance(stmt, If): - # JAX: jax.lax.cond(pred, true_fun, false_fun, operand) - # 但生成 jax.lax.cond 的完整函数定义太复杂, - # 如果 body 中只有简单语句,用 jnp.where 更实用 - # 否则用 jax.lax.cond - - # 策略:检测 then_body 和 else_body 的复杂度 - # 如果两者都是纯赋值/BufferAccum,用 jnp.where 系列 - # 如果有 Call/For/嵌套 If,用 jax.lax.cond - - is_simple = FEACompiler._is_simple_jax_if(stmt) - - if is_simple: - # 简单 If:逐语句生成 jnp.where 版本 - FEACompiler._emit_simple_if_jax(stmt, buffer_map, printer, lines, indent) - else: - # 复杂 If:用 jax.lax.cond - FEACompiler._emit_cond_if_jax(stmt, buffer_map, printer, lines, indent) - - elif isinstance(stmt, For): - idx = str(stmt.index) - start = stmt.start - end = stmt.end - - if stmt.unroll: - # 展开循环 - for i in range(int(start), int(end)): - lines.append(f"{pad}# Unrolled iteration {idx} = {i}") - sub_body = FEACompiler._substitute_index(stmt.body, stmt.index, i) - FEACompiler._emit_body_jax(sub_body, buffer_map, printer, lines, indent, for_indices) - else: - # jax.lax.fori_loop - # 需要把循环体封装为一个函数 - # 携带状态 = 所有缓冲区 + 循环体内修改的标量 - carried = FEACompiler._collect_carried_vars(stmt.body, buffer_map) - - if carried: - carry_names = sorted(carried) - carry_tuple = ", ".join(carry_names) - - # 生成循环体函数 - lines.append(f"{pad}def _for_body_{idx}({idx}, _carry):") - for i, name in enumerate(carry_names): - lines.append(f"{pad} {name} = _carry[{i}]") - - # 生成循环体 - FEACompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent + 1, for_indices) - - # 返回 carry - ret_parts = ", ".join(carry_names) - lines.append(f"{pad} return ({ret_parts},)") - lines.append("") - - # 调用 fori_loop - init_parts = ", ".join(carry_names) - if len(carry_names) == 1: - # 单元素 carry: fori_loop 返回 (val,), 需要 [0] 解包 - lines.append(f"{pad}{carry_tuple}, = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") - else: - lines.append(f"{pad}{carry_tuple} = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") - else: - # 无携带状态,循环体无副作用,只需调用一次 - lines.append(f"{pad}# For loop with no carried state — body executed once") - FEACompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent, for_indices) - - # ------------------------------------------------------------------------- - # JAX If 辅助方法 - # ------------------------------------------------------------------------- - @staticmethod - def _is_simple_jax_if(if_stmt): - """判断 If 语句是否足够简单,可以用 jnp.where 实现""" - def _is_simple_body(body): - for s in body: - if isinstance(s, (Call, For, If)): - return False - if isinstance(s, BufferAccum): - return False # BufferAccum 需要 += 语义 - return True - - return _is_simple_body(if_stmt.then_body) and _is_simple_body(if_stmt.else_body) - - @staticmethod - def _emit_simple_if_jax(if_stmt, buffer_map, printer, lines, indent): - """用 jnp.where 生成简单 If""" - pad = " " * indent - cond_str = FEACompiler._print_expr(if_stmt.cond, printer) - - # 先生成 then_body 的赋值 - for s in if_stmt.then_body: - if isinstance(s, Assign): - target = str(s.target) - then_val = FEACompiler._print_expr(s.expr, printer) - # 从 else_body 找同名赋值,或用原值 - else_val = target # 默认:不变 - for es in if_stmt.else_body: - if isinstance(es, Assign) and str(es.target) == target: - else_val = FEACompiler._print_expr(es.expr, printer) - break - if target in buffer_map: - lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") - else: - lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") - - elif isinstance(s, BufferFill): - buf = buffer_map[s.target] - fill_val = s.value - if fill_val == 0.0: - lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.zeros({buf.size}), {s.target})") - else: - lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.full({buf.size}, {fill_val}), {s.target})") - - elif isinstance(s, BufferCopy): - lines.append(f"{pad}{s.target} = jnp.where({cond_str}, {s.source}.copy(), {s.target})") - - # else_body 中独有的赋值 - then_targets = {str(s.target) for s in if_stmt.then_body if isinstance(s, (Assign, BufferFill, BufferCopy))} - for s in if_stmt.else_body: - if isinstance(s, Assign) and str(s.target) not in then_targets: - target = str(s.target) - else_val = FEACompiler._print_expr(s.expr, printer) - lines.append(f"{pad}{target} = jnp.where({cond_str}, {target}, {else_val})") - - @staticmethod - def _emit_cond_if_jax(if_stmt, buffer_map, printer, lines, indent): - """用 jax.lax.cond 生成复杂 If""" - pad = " " * indent - cond_str = FEACompiler._print_expr(if_stmt.cond, printer) - - # 收集 then/else 修改的变量 - then_carried = FEACompiler._collect_carried_vars(if_stmt.then_body, buffer_map) - else_carried = FEACompiler._collect_carried_vars(if_stmt.else_body, buffer_map) - carried = sorted(then_carried | else_carried) - - if not carried: - # 无副作用,直接生成 then_body(条件满足时执行) - # 但 JAX 的 cond 需要两边都有返回值 - # 简化:直接生成 then_body - FEACompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent) - return - - carry_tuple = ", ".join(carried) - - # then 函数 - lines.append(f"{pad}def _if_true(_carry):") - for i, name in enumerate(carried): - lines.append(f"{pad} {name} = _carry[{i}]") - FEACompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent + 1) - ret_parts = ", ".join(carried) - lines.append(f"{pad} return ({ret_parts},)") - lines.append("") - - # else 函数 - lines.append(f"{pad}def _if_false(_carry):") - for i, name in enumerate(carried): - lines.append(f"{pad} {name} = _carry[{i}]") - if if_stmt.else_body: - FEACompiler._emit_body_jax(if_stmt.else_body, buffer_map, printer, lines, indent + 1) - # else 必须和 then 返回相同结构 - lines.append(f"{pad} return ({ret_parts},)") - lines.append("") - - # 调用 jax.lax.cond - init_parts = ", ".join(carried) - if len(carried) == 1: - lines.append(f"{pad}{carry_tuple}, = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") - else: - lines.append(f"{pad}{carry_tuple} = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") - - # ------------------------------------------------------------------------- - # JAX 携带变量收集 - # ------------------------------------------------------------------------- - @staticmethod - def _collect_carried_vars(body, buffer_map): - """收集 body 中被修改的变量名(需要作为 fori_loop/cond 的 carry)""" - carried = set() - for stmt in body: - if isinstance(stmt, Assign): - carried.add(str(stmt.target)) - elif isinstance(stmt, BufferFill): - carried.add(stmt.target) - elif isinstance(stmt, BufferCopy): - carried.add(stmt.target) - elif isinstance(stmt, BufferAccum): - carried.add(stmt.target) - elif isinstance(stmt, Call): - for var_name in stmt.output_vars: - carried.add(var_name) - elif isinstance(stmt, If): - then_c = FEACompiler._collect_carried_vars(stmt.then_body, buffer_map) - else_c = FEACompiler._collect_carried_vars(stmt.else_body, buffer_map) - carried |= then_c | else_c - elif isinstance(stmt, For): - carried |= FEACompiler._collect_carried_vars(stmt.body, buffer_map) - return carried +class FEACompiler(MathCompiler): + """向后兼容的编译器门面,同时提供 Math 和 Flow 编译能力。""" + compile_flow = staticmethod(FlowCompiler.compile_flow) diff --git a/codegen/flow_compiler.py b/codegen/flow_compiler.py new file mode 100644 index 0000000..339a2af --- /dev/null +++ b/codegen/flow_compiler.py @@ -0,0 +1,841 @@ +import sympy as sp +from sympy.printing.numpy import JaxPrinter + +from codegen.model import ( + FlowModel, Assign, BufferFill, BufferCopy, BufferAccum, + Call, If, For, BufferRef, +) +from codegen.lowered import CachedPrinter +from codegen.printer import FEACodePrinter, FEAFortranPrinter +from codegen.math_compiler import MathCompiler + + +class FlowCompiler: + # ========================================================================= + # FlowModel 编译 — 命令式主流程代码生成 + # ========================================================================= + @staticmethod + def compile_flow(flow: FlowModel, target: str, chunk_size=None, cse_strategy="auto"): + """ + 编译 FlowModel,生成主流程函数 + 所有子模型函数的完整源码。 + + Args: + flow: FlowModel 实例 + target: 目标平台 ('cpp', 'cuda', 'fortran', 'jax') + chunk_size: 子模型 CSE chunk size (可选) + cse_strategy: CSE 策略 + + Returns: + str: 完整源码(包含子模型函数 + 主流程函数) + """ + target = target.lower() + if target in ('cpp', 'c++', 'cuda'): + is_cuda = (target == 'cuda') + return FlowCompiler._flow_to_source(flow, is_cuda=is_cuda, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'fortran': + return FlowCompiler._flow_to_fortran(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'jax': + return FlowCompiler._flow_to_jax(flow, chunk_size=chunk_size, cse_strategy=cse_strategy) + else: + raise ValueError(f"FlowModel does not support target '{target}' yet") + + # ------------------------------------------------------------------------- + # C++/CUDA Flow 代码生成 + # ------------------------------------------------------------------------- + @staticmethod + def _flow_to_source(flow: FlowModel, is_cuda=False, chunk_size=None, cse_strategy="auto"): + """生成 FlowModel 的 C++/CUDA 源码""" + + # 1. 编译所有子模型 → compute_xxx 函数源码(剥离重复的宏定义) + macros_str = MathCompiler._cpp_cuda_compat_macros() + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = MathCompiler.compile(sub_model, "cuda" if is_cuda else "cpp", + chunk_size=chunk_size, cse_strategy=cse_strategy) + # 剥离子模型源码中的宏定义(避免重复) + sub_src = sub_src.replace(macros_str, "").strip() + sub_sources.append(sub_src) + + # 2. 构建缓冲区查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + + # 3. 收集标量变量(Call 输出中不在 buffer_map 的) + # 需要先遍历 body 收集 + scalar_vars = set() + FlowCompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + + # 4. 生成主流程函数 + printer = CachedPrinter(FEACodePrinter()) + + # --- 函数注释 --- + comment_lines = ["/**"] + comment_lines.append(f" * @brief Flow kernel: {flow.name}") + comment_lines.append(" * ") + comment_lines.append(" * @param in Input array (const double*). Layout:") + for i, sym in enumerate(flow.inputs): + comment_lines.append(f" * - in[{i}]: {sym}") + comment_lines.append(" * ") + comment_lines.append(" * @param out Output array (double*). Layout:") + offset = 0 + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + comment_lines.append(f" * - out[{offset}..{offset + buf.size - 1}]: {s} (buffer, size={buf.size})") + offset += buf.size + else: + comment_lines.append(f" * - out[{offset}]: {s}") + offset += 1 + comment_lines.append(" */") + comment_block = "\n".join(comment_lines) + + # --- 函数体 --- + body_lines = [] + + # 解包输入 + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + body_lines.append(f" double {s} = in[{i}];") + body_lines.append("") + + # 声明缓冲区 + for buf in flow.local_buffers: + dtype = "double" if buf.dtype == "double" else buf.dtype + body_lines.append(f" {dtype} {buf.name}[{buf.size}];") + body_lines.append("") + + # 声明标量变量 + if scalar_vars: + for var in sorted(scalar_vars): + body_lines.append(f" double {var};") + body_lines.append("") + + # 生成 body + FlowCompiler._emit_body(flow.body, buffer_map, printer, body_lines, indent=1) + + # 输出映射(按累积偏移) + body_lines.append("") + body_lines.append(" // --- Output ---") + offset = 0 + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + for j in range(buf.size): + body_lines.append(f" out[{offset + j}] = {s}[{j}];") + offset += buf.size + else: + body_lines.append(f" out[{offset}] = {s};") + offset += 1 + + body = "\n".join(body_lines) + + # --- 函数签名 --- + prefix_macros = MathCompiler._cpp_cuda_compat_macros() + "\n" + if is_cuda: + func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" + else: + func_type = "FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{flow.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + + main_func = f"{comment_block}\n{signature} {{\n{body}\n}}" + + # 组装:宏定义(1次) + 子模型函数 + 主流程函数 + parts = [prefix_macros] + sub_sources + [main_func] + return "\n\n".join(parts) + + # ------------------------------------------------------------------------- + # Fortran Flow 代码生成 + # ------------------------------------------------------------------------- + @staticmethod + def _flow_to_fortran(flow: FlowModel, chunk_size=None, cse_strategy="auto"): + """生成 FlowModel 的 Fortran 源码""" + + # 1. 编译所有子模型 + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = MathCompiler.compile(sub_model, "fortran", + chunk_size=chunk_size, cse_strategy=cse_strategy) + sub_sources.append(sub_src) + + # 2. 构建查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + scalar_vars = set() + FlowCompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + # 收集 For 循环的 index 变量(Fortran 需声明为 integer) + for_indices = set() + FlowCompiler._collect_for_indices(flow.body, for_indices) + + # 3. 生成主流程 + printer = CachedPrinter(FEAFortranPrinter()) + + lines = [ + "! Generated by sympy_codegen.py. Do not edit.", + "!", + f"! Flow kernel: compute_{flow.name}", + "!", + f"subroutine compute_{flow.name}(in_vec, out_vec)", + " implicit none", + f" double precision, intent(in) :: in_vec({len(flow.inputs)})", + f" double precision, intent(out) :: out_vec({sum(buffer_map[str(o)].size if str(o) in buffer_map else 1 for o in flow.outputs)})", + ] + + # 解包输入 + input_vars = [] + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + input_vars.append(s) + + if input_vars: + lines.append(" ! --- Unpack inputs ---") + lines.append(f" double precision :: {', '.join(input_vars)}") + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_vec({i + 1})") + lines.append("") + + # 声明缓冲区 + if flow.local_buffers: + lines.append(" ! --- Local Buffers ---") + for buf in flow.local_buffers: + lines.append(f" double precision :: {buf.name}({buf.size})") + lines.append("") + + # 声明标量变量 + if scalar_vars: + lines.append(" ! --- Local Scalars ---") + lines.append(f" double precision :: {', '.join(sorted(scalar_vars))}") + lines.append("") + + # 声明 For 循环 index 变量 + if for_indices: + lines.append(" ! --- Loop Indices ---") + lines.append(f" integer :: {', '.join(sorted(for_indices))}") + lines.append("") + + # 生成 body + FlowCompiler._emit_body_fortran(flow.body, buffer_map, printer, lines, indent=1) + + # 输出映射 + lines.append("") + lines.append(" ! --- Output ---") + offset = 0 + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + buf = buffer_map[s] + for j in range(buf.size): + lines.append(f" out_vec({offset + j + 1}) = {s}({j + 1})") + offset += buf.size + else: + lines.append(f" out_vec({offset + 1}) = {s}") + offset += 1 + + lines.append(f"end subroutine compute_{flow.name}") + + # 组装 + parts = sub_sources + ["\n".join(lines)] + return "\n\n".join(parts) + + # ========================================================================= + # 辅助方法:标量变量收集 + # ========================================================================= + @staticmethod + def _collect_scalar_vars(body, buffer_map, scalar_vars): + """递归收集 body 中 Call 输出的标量变量名""" + for stmt in body: + if isinstance(stmt, Call): + for var_name in stmt.output_vars: + if var_name not in buffer_map: + scalar_vars.add(var_name) + elif isinstance(stmt, If): + FlowCompiler._collect_scalar_vars(stmt.then_body, buffer_map, scalar_vars) + FlowCompiler._collect_scalar_vars(stmt.else_body, buffer_map, scalar_vars) + elif isinstance(stmt, For): + FlowCompiler._collect_scalar_vars(stmt.body, buffer_map, scalar_vars) + + @staticmethod + def _collect_for_indices(body, indices): + """递归收集 For 语句的 index 变量名""" + for stmt in body: + if isinstance(stmt, For): + indices.add(str(stmt.index)) + FlowCompiler._collect_for_indices(stmt.body, indices) + elif isinstance(stmt, If): + FlowCompiler._collect_for_indices(stmt.then_body, indices) + FlowCompiler._collect_for_indices(stmt.else_body, indices) + + # ========================================================================= + # 辅助方法:C++/CUDA 语句生成 + # ========================================================================= + @staticmethod + def _emit_body(body, buffer_map, printer, lines, indent=1): + """递归生成 C++/CUDA 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = MathCompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str};") + + elif isinstance(stmt, BufferFill): + buf = buffer_map[stmt.target] + fill_val = stmt.value + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {fill_val};") + + elif isinstance(stmt, BufferCopy): + buf = buffer_map[stmt.target] + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] = {stmt.source}[_i];") + + elif isinstance(stmt, BufferAccum): + buf = buffer_map[stmt.target] + lines.append(f"{pad}for (int _i = 0; _i < {buf.size}; _i++) {stmt.target}[_i] += {stmt.source}[_i];") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + # 计算子模型的总输入/输出数量 + # input_exprs 中可能有标量引用或 SymPy 表达式 + # 当 input_exprs 元素是 str 且在 buffer_map 中时,需要逐元素填充 + lines.append(f"{pad}// Call: {model_name}") + # 计算总输入元素数 + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + # 计算总输出元素数(buffer 展开为逐元素) + total_out = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + total_out += buffer_map[var_name].size + else: + total_out += 1 + + lines.append(f"{pad}{{") + lines.append(f"{pad} double _call_in[{total_in}];") + # 填充输入 + in_idx = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + buf = buffer_map[e] + for j in range(buf.size): + lines.append(f"{pad} _call_in[{in_idx}] = {e}[{j}];") + in_idx += 1 + else: + lines.append(f"{pad} _call_in[{in_idx}] = {MathCompiler._print_expr(e, printer)};") + in_idx += 1 + + # 输出:构造统一的 _call_out 数组,然后按累积偏移分发 + lines.append(f"{pad} double _call_out[{total_out}];") + lines.append(f"{pad} compute_{model_name}(_call_in, _call_out);") + + # 按累积偏移分发输出 + offset = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + buf = buffer_map[var_name] + for j in range(buf.size): + lines.append(f"{pad} {var_name}[{j}] = _call_out[{offset + j}];") + offset += buf.size + else: + lines.append(f"{pad} {var_name} = _call_out[{offset}];") + offset += 1 + + lines.append(f"{pad}}}") + + elif isinstance(stmt, If): + cond_str = MathCompiler._print_expr(stmt.cond, printer) + lines.append(f"{pad}if ({cond_str}) {{") + FlowCompiler._emit_body(stmt.then_body, buffer_map, printer, lines, indent + 1) + if stmt.else_body: + lines.append(f"{pad}}} else {{") + FlowCompiler._emit_body(stmt.else_body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}}}") + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + # 展开循环 + for i in range(int(start), int(end)): + lines.append(f"{pad}// Unrolled iteration {idx} = {i}") + # 替换 body 中引用 index 的表达式 + sub_body = MathCompiler._substitute_index(stmt.body, stmt.index, i) + FlowCompiler._emit_body(sub_body, buffer_map, printer, lines, indent) + else: + lines.append(f"{pad}for (int {idx} = {start}; {idx} < {end}; {idx}++) {{") + FlowCompiler._emit_body(stmt.body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}}}") + + # ========================================================================= + # 辅助方法:Fortran 语句生成 + # ========================================================================= + @staticmethod + def _emit_body_fortran(body, buffer_map, printer, lines, indent=1): + """递归生成 Fortran 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = MathCompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str}") + + elif isinstance(stmt, BufferFill): + buf = buffer_map[stmt.target] + fill_val = stmt.value + if fill_val == 0.0: + lines.append(f"{pad}{stmt.target}(:) = 0.0d0") + else: + lines.append(f"{pad}{stmt.target}(:) = {fill_val}d0") + + elif isinstance(stmt, BufferCopy): + lines.append(f"{pad}{stmt.target}(:) = {stmt.source}(:)") + + elif isinstance(stmt, BufferAccum): + buf = buffer_map[stmt.target] + lines.append(f"{pad}{stmt.target}(:) = {stmt.target}(:) + {stmt.source}(:)") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + + # 计算总输入元素数(buffer 展开为逐元素) + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + lines.append(f"{pad}! Call: {model_name}") + lines.append(f"{pad}block") + lines.append(f"{pad} double precision :: _call_in({total_in})") + # 计算总输出元素数(buffer 展开为逐元素) + total_out = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + total_out += buffer_map[var_name].size + else: + total_out += 1 + lines.append(f"{pad} double precision :: _call_out({total_out})") + + in_idx = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + buf = buffer_map[e] + for j in range(buf.size): + lines.append(f"{pad} _call_in({in_idx + 1}) = {e}({j + 1})") + in_idx += 1 + else: + lines.append(f"{pad} _call_in({in_idx + 1}) = {MathCompiler._print_expr(e, printer)}") + in_idx += 1 + + lines.append(f"{pad} call compute_{model_name}(_call_in, _call_out)") + + # 按累积偏移分发输出 + offset = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + buf = buffer_map[var_name] + for j in range(buf.size): + lines.append(f"{pad} {var_name}({j + 1}) = _call_out({offset + j + 1})") + offset += buf.size + else: + lines.append(f"{pad} {var_name} = _call_out({offset + 1})") + offset += 1 + + lines.append(f"{pad}end block") + + elif isinstance(stmt, If): + cond_str = MathCompiler._print_expr(stmt.cond, printer) + lines.append(f"{pad}if ({cond_str}) then") + FlowCompiler._emit_body_fortran(stmt.then_body, buffer_map, printer, lines, indent + 1) + if stmt.else_body: + lines.append(f"{pad}else") + FlowCompiler._emit_body_fortran(stmt.else_body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}end if") + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + for i in range(int(start), int(end)): + lines.append(f"{pad}! Unrolled iteration {idx} = {i}") + sub_body = MathCompiler._substitute_index(stmt.body, stmt.index, i) + FlowCompiler._emit_body_fortran(sub_body, buffer_map, printer, lines, indent) + else: + lines.append(f"{pad}do {idx} = {start}, {end} - 1") + FlowCompiler._emit_body_fortran(stmt.body, buffer_map, printer, lines, indent + 1) + lines.append(f"{pad}end do") + + # ========================================================================= + # JAX Flow 代码生成 + # ========================================================================= + @staticmethod + def _flow_to_jax(flow: FlowModel, chunk_size=None, cse_strategy="auto"): + """ + 生成 FlowModel 的 JAX 源码。 + + JAX 是纯函数式的,生成策略: + - 缓冲区 → jnp.zeros / jnp.array + - BufferAccum → buf = buf.at[i].add(src[i]) 或 buf = buf + src + - For(unroll=False) → jax.lax.fori_loop + - For(unroll=True) → Python 展开循环 + - If → jax.lax.cond (then/else 必须返回相同结构的 tuple) + - Call → 调用子模型函数 + """ + + # 1. 编译所有子模型 + sub_sources = [] + for name, sub_model in flow.submodels.items(): + sub_src = MathCompiler.compile(sub_model, "jax", + chunk_size=chunk_size, cse_strategy=cse_strategy) + sub_sources.append(sub_src) + + # 2. 构建查找表 + buffer_map = {b.name: b for b in flow.local_buffers} + scalar_vars = set() + FlowCompiler._collect_scalar_vars(flow.body, buffer_map, scalar_vars) + for_indices = set() + FlowCompiler._collect_for_indices(flow.body, for_indices) + + # 3. 生成主流程 + printer = CachedPrinter(JaxPrinter()) + + lines = [ + '"""Generated by sympy_codegen.py. Do not edit."""', + "import jax", + "import jax.numpy as jnp", + "", + "", + f"def compute_{flow.name}(in_flat):", + f' """', + f' Flow kernel: {flow.name}', + f' ', + f' Args:', + f' in_flat: Flattened input array, size {len(flow.inputs)}', + f' ', + f' Returns:', + f' Flattened output array', + f' """', + ] + + # 解包输入 + for i, sym in enumerate(flow.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_flat[{i}]") + lines.append("") + + # 初始化缓冲区 + for buf in flow.local_buffers: + if buf.dtype == "double": + lines.append(f" {buf.name} = jnp.zeros({buf.size})") + else: + lines.append(f" {buf.name} = jnp.zeros({buf.size}, dtype=jnp.{buf.dtype})") + lines.append("") + + # 声明标量变量初始值(用于 JAX 的函数式风格) + # JAX 中标量不需要预声明,在赋值时绑定即可 + + # 生成 body + FlowCompiler._emit_body_jax(flow.body, buffer_map, printer, lines, indent=1, for_indices=for_indices) + + # 输出映射 + lines.append("") + lines.append(" # --- Output ---") + out_parts = [] + for out_name in flow.outputs: + s = str(out_name) + if s in buffer_map: + out_parts.append(s) + else: + out_parts.append(s) + + if len(out_parts) == 1: + lines.append(f" return {out_parts[0]}") + else: + lines.append(f" return jnp.concatenate([{', '.join(out_parts)}])") + + src = "\n".join(sub_sources) + "\n\n\n" + "\n".join(lines) + src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") + return src + + # ------------------------------------------------------------------------- + # JAX 语句生成 + # ------------------------------------------------------------------------- + @staticmethod + def _emit_body_jax(body, buffer_map, printer, lines, indent=1, for_indices=None): + """递归生成 JAX 语句""" + pad = " " * indent + for stmt in body: + if isinstance(stmt, Assign): + target = str(stmt.target) + expr_str = MathCompiler._print_expr(stmt.expr, printer) + lines.append(f"{pad}{target} = {expr_str}") + + elif isinstance(stmt, BufferFill): + buf = buffer_map[stmt.target] + fill_val = stmt.value + if fill_val == 0.0: + lines.append(f"{pad}{stmt.target} = jnp.zeros({buf.size})") + else: + lines.append(f"{pad}{stmt.target} = jnp.full({buf.size}, {fill_val})") + + elif isinstance(stmt, BufferCopy): + lines.append(f"{pad}{stmt.target} = {stmt.source}.copy()") + + elif isinstance(stmt, BufferAccum): + buf_target = buffer_map[stmt.target] + buf_source = buffer_map[stmt.source] + if buf_target.size == buf_source.size: + # 同尺寸:向量加法 + lines.append(f"{pad}{stmt.target} = {stmt.target} + {stmt.source}") + else: + # 不同尺寸:逐元素 at[].add()(不太常见,但保留安全路径) + lines.append(f"{pad}{stmt.target} = {stmt.target}.at[:{buf_source.size}].add({stmt.source})") + + elif isinstance(stmt, Call): + model_name = stmt.model_name + + # 计算总输入元素数 + total_in = 0 + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + total_in += buffer_map[e].size + else: + total_in += 1 + + # 构建输入 + in_parts = [] + for e in stmt.input_exprs: + if isinstance(e, str) and e in buffer_map: + in_parts.append(e) + else: + in_parts.append(MathCompiler._print_expr(e, printer)) + + lines.append(f"{pad}# Call: {model_name}") + if len(in_parts) == 1: + # 单个输入,可能是标量或数组 + lines.append(f"{pad}_call_in = jnp.array([{in_parts[0]}]) if not isinstance({in_parts[0]}, jnp.ndarray) else {in_parts[0]}.reshape(-1)") + else: + # 多个输入,拼接为 1D 数组 + items = ", ".join(in_parts) + lines.append(f"{pad}_call_in = jnp.concatenate([jnp.atleast_1d(jnp.asarray(x)) for x in [{items}]])") + + lines.append(f"{pad}_call_out = compute_{model_name}(_call_in)") + + # 分发输出 + offset = 0 + for var_name in stmt.output_vars: + if var_name in buffer_map: + buf = buffer_map[var_name] + lines.append(f"{pad}{var_name} = _call_out[{offset}:{offset + buf.size}]") + offset += buf.size + else: + lines.append(f"{pad}{var_name} = _call_out[{offset}]") + offset += 1 + + elif isinstance(stmt, If): + # JAX: jax.lax.cond(pred, true_fun, false_fun, operand) + # 但生成 jax.lax.cond 的完整函数定义太复杂, + # 如果 body 中只有简单语句,用 jnp.where 更实用 + # 否则用 jax.lax.cond + + # 策略:检测 then_body 和 else_body 的复杂度 + # 如果两者都是纯赋值/BufferAccum,用 jnp.where 系列 + # 如果有 Call/For/嵌套 If,用 jax.lax.cond + + is_simple = FlowCompiler._is_simple_jax_if(stmt) + + if is_simple: + # 简单 If:逐语句生成 jnp.where 版本 + FlowCompiler._emit_simple_if_jax(stmt, buffer_map, printer, lines, indent) + else: + # 复杂 If:用 jax.lax.cond + FlowCompiler._emit_cond_if_jax(stmt, buffer_map, printer, lines, indent) + + elif isinstance(stmt, For): + idx = str(stmt.index) + start = stmt.start + end = stmt.end + + if stmt.unroll: + # 展开循环 + for i in range(int(start), int(end)): + lines.append(f"{pad}# Unrolled iteration {idx} = {i}") + sub_body = MathCompiler._substitute_index(stmt.body, stmt.index, i) + FlowCompiler._emit_body_jax(sub_body, buffer_map, printer, lines, indent, for_indices) + else: + # jax.lax.fori_loop + # 需要把循环体封装为一个函数 + # 携带状态 = 所有缓冲区 + 循环体内修改的标量 + carried = FlowCompiler._collect_carried_vars(stmt.body, buffer_map) + + if carried: + carry_names = sorted(carried) + carry_tuple = ", ".join(carry_names) + + # 生成循环体函数 + lines.append(f"{pad}def _for_body_{idx}({idx}, _carry):") + for i, name in enumerate(carry_names): + lines.append(f"{pad} {name} = _carry[{i}]") + + # 生成循环体 + FlowCompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent + 1, for_indices) + + # 返回 carry + ret_parts = ", ".join(carry_names) + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # 调用 fori_loop + init_parts = ", ".join(carry_names) + if len(carry_names) == 1: + # 单元素 carry: fori_loop 返回 (val,), 需要 [0] 解包 + lines.append(f"{pad}{carry_tuple}, = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") + else: + lines.append(f"{pad}{carry_tuple} = jax.lax.fori_loop({start}, {end}, _for_body_{idx}, ({init_parts},))") + else: + # 无携带状态,循环体无副作用,只需调用一次 + lines.append(f"{pad}# For loop with no carried state — body executed once") + FlowCompiler._emit_body_jax(stmt.body, buffer_map, printer, lines, indent, for_indices) + + # ------------------------------------------------------------------------- + # JAX If 辅助方法 + # ------------------------------------------------------------------------- + @staticmethod + def _is_simple_jax_if(if_stmt): + """判断 If 语句是否足够简单,可以用 jnp.where 实现""" + def _is_simple_body(body): + for s in body: + if isinstance(s, (Call, For, If)): + return False + if isinstance(s, BufferAccum): + return False # BufferAccum 需要 += 语义 + return True + + return _is_simple_body(if_stmt.then_body) and _is_simple_body(if_stmt.else_body) + + @staticmethod + def _emit_simple_if_jax(if_stmt, buffer_map, printer, lines, indent): + """用 jnp.where 生成简单 If""" + pad = " " * indent + cond_str = MathCompiler._print_expr(if_stmt.cond, printer) + + # 先生成 then_body 的赋值 + for s in if_stmt.then_body: + if isinstance(s, Assign): + target = str(s.target) + then_val = MathCompiler._print_expr(s.expr, printer) + # 从 else_body 找同名赋值,或用原值 + else_val = target # 默认:不变 + for es in if_stmt.else_body: + if isinstance(es, Assign) and str(es.target) == target: + else_val = MathCompiler._print_expr(es.expr, printer) + break + if target in buffer_map: + lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") + else: + lines.append(f"{pad}{target} = jnp.where({cond_str}, {then_val}, {else_val})") + + elif isinstance(s, BufferFill): + buf = buffer_map[s.target] + fill_val = s.value + if fill_val == 0.0: + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.zeros({buf.size}), {s.target})") + else: + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, jnp.full({buf.size}, {fill_val}), {s.target})") + + elif isinstance(s, BufferCopy): + lines.append(f"{pad}{s.target} = jnp.where({cond_str}, {s.source}.copy(), {s.target})") + + # else_body 中独有的赋值 + then_targets = {str(s.target) for s in if_stmt.then_body if isinstance(s, (Assign, BufferFill, BufferCopy))} + for s in if_stmt.else_body: + if isinstance(s, Assign) and str(s.target) not in then_targets: + target = str(s.target) + else_val = MathCompiler._print_expr(s.expr, printer) + lines.append(f"{pad}{target} = jnp.where({cond_str}, {target}, {else_val})") + + @staticmethod + def _emit_cond_if_jax(if_stmt, buffer_map, printer, lines, indent): + """用 jax.lax.cond 生成复杂 If""" + pad = " " * indent + cond_str = MathCompiler._print_expr(if_stmt.cond, printer) + + # 收集 then/else 修改的变量 + then_carried = FlowCompiler._collect_carried_vars(if_stmt.then_body, buffer_map) + else_carried = FlowCompiler._collect_carried_vars(if_stmt.else_body, buffer_map) + carried = sorted(then_carried | else_carried) + + if not carried: + # 无副作用,直接生成 then_body(条件满足时执行) + # 但 JAX 的 cond 需要两边都有返回值 + # 简化:直接生成 then_body + FlowCompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent) + return + + carry_tuple = ", ".join(carried) + + # then 函数 + lines.append(f"{pad}def _if_true(_carry):") + for i, name in enumerate(carried): + lines.append(f"{pad} {name} = _carry[{i}]") + FlowCompiler._emit_body_jax(if_stmt.then_body, buffer_map, printer, lines, indent + 1) + ret_parts = ", ".join(carried) + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # else 函数 + lines.append(f"{pad}def _if_false(_carry):") + for i, name in enumerate(carried): + lines.append(f"{pad} {name} = _carry[{i}]") + if if_stmt.else_body: + FlowCompiler._emit_body_jax(if_stmt.else_body, buffer_map, printer, lines, indent + 1) + # else 必须和 then 返回相同结构 + lines.append(f"{pad} return ({ret_parts},)") + lines.append("") + + # 调用 jax.lax.cond + init_parts = ", ".join(carried) + if len(carried) == 1: + lines.append(f"{pad}{carry_tuple}, = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") + else: + lines.append(f"{pad}{carry_tuple} = jax.lax.cond({cond_str}, _if_true, _if_false, ({init_parts},))") + + # ------------------------------------------------------------------------- + # JAX 携带变量收集 + # ------------------------------------------------------------------------- + @staticmethod + def _collect_carried_vars(body, buffer_map): + """收集 body 中被修改的变量名(需要作为 fori_loop/cond 的 carry)""" + carried = set() + for stmt in body: + if isinstance(stmt, Assign): + carried.add(str(stmt.target)) + elif isinstance(stmt, BufferFill): + carried.add(stmt.target) + elif isinstance(stmt, BufferCopy): + carried.add(stmt.target) + elif isinstance(stmt, BufferAccum): + carried.add(stmt.target) + elif isinstance(stmt, Call): + for var_name in stmt.output_vars: + carried.add(var_name) + elif isinstance(stmt, If): + then_c = FlowCompiler._collect_carried_vars(stmt.then_body, buffer_map) + else_c = FlowCompiler._collect_carried_vars(stmt.else_body, buffer_map) + carried |= then_c | else_c + elif isinstance(stmt, For): + carried |= FlowCompiler._collect_carried_vars(stmt.body, buffer_map) + return carried diff --git a/codegen/math_compiler.py b/codegen/math_compiler.py new file mode 100644 index 0000000..970e70a --- /dev/null +++ b/codegen/math_compiler.py @@ -0,0 +1,656 @@ +import re + +import sympy as sp +from sympy.core.relational import Relational +from sympy.printing.numpy import JaxPrinter + +from codegen.model import ( + MathModel, Assign, BufferFill, BufferCopy, BufferAccum, + Call, If, For, BufferRef, +) +from codegen.lowered import LoweredChunk, LoweredModel, CachedPrinter +from codegen.printer import FEACodePrinter, FEAFortranPrinter +from definitions.abc import Element + + +class MathCompiler: + # ========================================================================= + # 公共 Lower 阶段:将 MathModel 转换为 LoweredModel,执行 CSE + # ========================================================================= + @staticmethod + def lower_model(model: MathModel, chunk_size: int) -> LoweredModel: + """执行 CSE lowering,返回可被多个后端共享的 LoweredModel""" + outputs = model.outputs + chunks = [] + + for start in range(0, len(outputs), chunk_size): + chunk_index = start // chunk_size + chunk = outputs[start:start + chunk_size] + sub_exprs, simplified_chunk = sp.cse( + chunk, + symbols=sp.numbered_symbols(f"v_{chunk_index}_") + ) + chunks.append( + LoweredChunk( + chunk_index=chunk_index, + start_index=start, + sub_exprs=sub_exprs, + simplified_outputs=simplified_chunk + ) + ) + + return LoweredModel(model.name, chunk_size, chunks, + external_calls=model.external_calls, + external_ops=model.external_ops) + + # ========================================================================= + # Chunk Size 策略:根据模型规模和目标平台决定 chunk size + # ========================================================================= + @staticmethod + def resolve_chunk_size(model: MathModel, target: str, user_chunk_size=None, strategy="auto") -> int: + """ + 决定 CSE chunk size 的策略。 + + Args: + model: 数学模型 + target: 目标平台 (jax/cpp/cuda/fortran等) + user_chunk_size: 用户通过 CLI 指定的 chunk size + strategy: 策略模式 ("auto" 或 "fixed") + + Returns: + 最终的 chunk size + """ + if user_chunk_size is not None: + return user_chunk_size + + nout = len(model.outputs) + target = target.lower() + + # fixed 模式:使用各后端的固定默认值 + if strategy == "fixed": + if target == "jax": + return 50 + if target in ("cpp", "c++", "cuda", "fortran"): + return 24 + return 24 + + # auto 模式:根据输出规模自动调整 + if strategy == "auto": + if target == "jax": + if nout <= 64: + return 64 + elif nout <= 256: + return 48 + else: + return 32 + + # cpp/cuda/fortran 的自适应策略 + if nout <= 32: + return 32 + elif nout <= 128: + return 24 + elif nout <= 512: + return 16 + else: + return 8 + + raise ValueError(f"Unknown strategy: {strategy}") + + # ========================================================================= + # C++/CUDA 兼容性宏:跨平台支持 GCC/Clang/MSVC/CUDA + # ========================================================================= + @staticmethod + def _cpp_cuda_compat_macros() -> str: + """返回统一的 C++/CUDA 跨平台兼容性宏定义""" + return r""" +#if defined(__CUDACC__) + #define FEA_DEVICE __device__ + #define FEA_HOST __host__ + #define FEA_HOST_DEVICE __host__ __device__ + #define FEA_RESTRICT __restrict__ +#else + #define FEA_DEVICE + #define FEA_HOST + #define FEA_HOST_DEVICE + #if defined(_WIN32) || defined(_WIN64) + #if defined(_MSC_VER) + #define FEA_RESTRICT __restrict + #else + #define FEA_RESTRICT __restrict__ + #endif + #else + #if defined(__GNUC__) || defined(__clang__) + #define FEA_RESTRICT __restrict__ + #else + #define FEA_RESTRICT + #endif + #endif +#endif + +#if defined(_MSC_VER) + #define FEA_ALWAYS_INLINE __forceinline +#elif defined(__GNUC__) || defined(__clang__) + #define FEA_ALWAYS_INLINE inline __attribute__((always_inline)) +#else + #define FEA_ALWAYS_INLINE inline +#endif +""" + + # ========================================================================= + # 核心编译接口 + # ========================================================================= + @staticmethod + def compile(model: MathModel, target: str, chunk_size=None, cse_strategy="auto", lowered=None): + """ + 核心分发器:输入 MathModel + target,输出 cpp/cuda/jax/fortran 源码字符串。 + + Args: + model: 数学模型 + target: 目标平台 ('jax', 'cpp', 'cuda', 'fortran') + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + lowered: 预先 lowered 的结果 (可选,用于多后端共享) + """ + target = target.lower() + if target == 'jax': + return MathCompiler._to_jax(model, lowered=lowered, chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target in ['cpp', 'c++']: + return MathCompiler._to_source(model, is_cuda=False, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'cuda': + return MathCompiler._to_source(model, is_cuda=True, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + elif target == 'fortran': + return MathCompiler._to_fortran(model, lowered=lowered, + chunk_size=chunk_size, cse_strategy=cse_strategy) + else: + raise ValueError(f"Unknown target: {target}") + + @staticmethod + def compile_all(model: MathModel, chunk_size=None, cse_strategy="auto", test=False, + task=None, model_name=None): + """ + 一次性生成 jax/cpp/cuda/fortran 四种目标源码。 + + 统一管理 lower 行为: + - 如果所有 target 使用相同的 chunk size,共享一份 lowered + - 如果 JAX 和 cpp/cuda/fortran 使用不同的 chunk size,分别生成 jax_lowered 和 shared_lowered + + Args: + model: 数学模型 + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + test: 是否同时生成测试资产(wrapper、test_driver、build script) + task: CLI 任务类型 ('constitutive', 'stiffness', 'mass', 'custom'),用于 test_driver 重新加载模型 + model_name: 模型/材料/单元名称,用于 test_driver 重新加载模型 + + Returns: + dict: {'jax': code, 'cpp': code, 'cuda': code, 'fortran': code, + 'cpp_wrapper': str, 'f90_wrapper': str, 'test_driver': str, + 'build_sh': str, 'build_bat': str} (后5项仅在 test=True 时存在) + """ + from ci_test.wrappers import generate_cpp_main, generate_f90_main + from ci_test.test_driver_template import generate_test_driver + from ci_test.build_script_generator import generate_build_sh, generate_build_bat + + # 决定各 target 的 chunk size + cpp_chunk = MathCompiler.resolve_chunk_size(model, "cpp", chunk_size, cse_strategy) + jax_chunk = MathCompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) + + # 生成 shared lowered 给 cpp/cuda/fortran + shared_lowered = MathCompiler.lower_model(model, cpp_chunk) + + # 决定 JAX 是否共享 lowered + if jax_chunk == cpp_chunk: + jax_lowered = shared_lowered + else: + jax_lowered = MathCompiler.lower_model(model, jax_chunk) + + result = { + "jax": MathCompiler._to_jax(model, lowered=jax_lowered, chunk_size=jax_chunk, cse_strategy=cse_strategy), + "cpp": MathCompiler._to_source(model, is_cuda=False, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + "cuda": MathCompiler._to_source(model, is_cuda=True, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + "fortran": MathCompiler._to_fortran(model, lowered=shared_lowered, chunk_size=cpp_chunk, cse_strategy=cse_strategy), + } + + if test: + result["cpp_wrapper"] = generate_cpp_main(model) + result["f90_wrapper"] = generate_f90_main(model) + result["test_driver"] = generate_test_driver(model, task=task, model_name=model_name) + result["build_sh"] = generate_build_sh(model) + result["build_bat"] = generate_build_bat(model) + + return result + + @staticmethod + def _to_jax(model, lowered=None, chunk_size=None, cse_strategy="auto"): + """ + 生成 JAX 源码(.py),采用分块 CSE 优化。 + + Args: + model: 数学模型 + lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + """ + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = MathCompiler.resolve_chunk_size(model, "jax", chunk_size, cse_strategy) + lowered = MathCompiler.lower_model(model, chunk_size) + + lines = [ + '"""Generated by sympy_codegen.py. Do not edit."""', + "import jax.numpy as jnp", + "", + "", + f"def compute_{model.name}(in_flat):", + f' """', + f' Compute the {model.name} kernel.', + f' ', + f' Args:', + f' in_flat: Flattened input array, size {len(model.inputs)}', + f' ', + f' Returns:', + f' Flattened output array, size {len(model.outputs)}', + f' ', + f' Input layout:', + ] + + # 添加输入信息 + for i, name in enumerate(model.input_names): + lines.append(f" ' - in_flat[{i}]: {name}") + + lines.append(f" '") + lines.append(f" ' Output layout:") + + # 添加输出信息 + for i, name in enumerate(model.output_names): + lines.append(f" ' - out[{i}]: {name}") + + lines.append(f' """') + + # Unpack inputs IF they are valid identifiers (e.g. xi, c0) + # If they are like "in[0]", we'll handle them via string replacement later + for i, sym in enumerate(model.inputs): + s = str(sym) + is_ident = s.isidentifier() + # print(f"DEBUG: sym={s}, is_ident={is_ident}") + if is_ident: + lines.append(f" {s} = in_flat[{i}]") + + lines.append("") + + printer = CachedPrinter(JaxPrinter()) + all_simplified_outputs = [] + + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + if op.jax_func is None: + raise ValueError( + f"External operator '{call.op_name}' has no JAX implementation. " + f"Cannot generate JAX code for model '{model.name}'." + ) + lines.append(f" # --- External Operator: {call.op_name} ---") + in_parts = ", ".join(printer.doprint(e) for e in call.input_exprs) + lines.append(f" {call.prefix}_in = jnp.array([{in_parts}])") + lines.append(f" {call.prefix}_out = {op.jax_func}({call.prefix}_in)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out[{i}]") + lines.append("") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + for var, expr in chunk.sub_exprs: + lines.append(f" {var} = {printer.doprint(expr)}") + + all_simplified_outputs.extend(chunk.simplified_outputs) + + lines.append("") + lines.append(" # --- Output ---") + out_parts = [printer.doprint(e) for e in all_simplified_outputs] + lines.append(f" return ({','.join(out_parts)})") + + src = "\n".join(lines) + # Final cleanup for JAX and handle C-style inputs + src = src.replace("jax.numpy.", "jnp.").replace("in[", "in_flat[") + return src + + @staticmethod + def _to_source(model, is_cuda=False, lowered=None, chunk_size=None, cse_strategy="auto"): + """ + 生成 C++/CUDA 源码,采用分块 CSE 优化及算子化增强。 + + Args: + model: 数学模型 + is_cuda: 是否为 CUDA 目标 + lowered: 预先 lowered 的 LoweredModel (可选,用于多后端共享) + chunk_size: 用户指定的 chunk size (可选) + cse_strategy: CSE 策略 ('auto' 或 'fixed') + """ + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = MathCompiler.resolve_chunk_size(model, "cuda" if is_cuda else "cpp", + chunk_size, cse_strategy) + lowered = MathCompiler.lower_model(model, chunk_size) + + # --- Generate Comments --- + comment_lines = ["/**"] + comment_lines.append(f" * @brief Computes the {model.name} kernel.") + if model.is_operator: + comment_lines.append(" * @note This is an optimized operator kernel.") + comment_lines.append(" * ") + comment_lines.append(" * @param in Input array (const double*). Layout:") + + for i, name in enumerate(model.input_names): + comment_lines.append(f" * - in[{i}]: {name}") + + comment_lines.append(" * ") + comment_lines.append(" * @param out Output array (double*). Layout:") + + # 列出每个输出的详细信息 + for i, name in enumerate(model.output_names): + comment_lines.append(f" * - out[{i}]: {name}") + + comment_lines.append(" */") + comment_block = "\n".join(comment_lines) + + # --- Generate Function Body --- + body_lines = [] + + # 解包输入变量 + for i, sym in enumerate(model.inputs): + s = str(sym) + # 检查是否是合法标识符(如 coord_2_3),如果是则解包 + if s.isidentifier(): + body_lines.append(f" double {s} = in[{i}];") + + body_lines.append("") + + # 初始化带缓存的 Printer + printer = CachedPrinter(FEACodePrinter()) + + # 外部算子调用 + if lowered.external_calls: + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + body_lines.append(f" // --- External Operator: {call.op_name} ---") + body_lines.append(f" double {call.prefix}_in[{op.n_inputs}];") + body_lines.append(f" double {call.prefix}_out[{op.n_outputs}];") + + for i, expr in enumerate(call.input_exprs): + body_lines.append(f" {call.prefix}_in[{i}] = {printer.doprint(expr)};") + + body_lines.append(f" {op.cpp_func}({call.prefix}_in, {call.prefix}_out);") + + for i, sym in enumerate(call.output_symbols): + body_lines.append(f" double {sym} = {call.prefix}_out[{i}];") + + body_lines.append("") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + body_lines.append(f"\n // --- Chunk {chunk.chunk_index} ---") + + for var, expr in chunk.sub_exprs: + body_lines.append(f" double {var} = {printer.doprint(expr)};") + + for j, out_expr in enumerate(chunk.simplified_outputs): + body_lines.append(f" out[{chunk.start_index + j}] = {printer.doprint(out_expr)};") + + body = "\n".join(body_lines) + + # 统一使用兼容宏体系 + prefix = MathCompiler._cpp_cuda_compat_macros() + "\n" + + if is_cuda: + # CUDA 使用 FEA_DEVICE 宏 + func_type = "FEA_DEVICE FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + else: + # C++ 使用 FEA_ALWAYS_INLINE 宏 + func_type = "FEA_ALWAYS_INLINE void" + signature = f"{func_type} compute_{model.name}(const double* FEA_RESTRICT in, double* FEA_RESTRICT out)" + + return f"{prefix}{comment_block}\n{signature} {{ \n{body}\n}}" + + + + @staticmethod + def _to_fortran(model, lowered=None, chunk_size=None, cse_strategy="auto"): + """生成 Fortran 源码,支持分块 CSE 优化。声明和赋值必须分离。""" + # 如果没有提供 lowered 结果,则自行 lower + if lowered is None: + chunk_size = MathCompiler.resolve_chunk_size(model, "fortran", chunk_size, cse_strategy) + lowered = MathCompiler.lower_model(model, chunk_size) + + printer = CachedPrinter(FEAFortranPrinter()) + + def _fortran_declare(type_decl, vars_list, indent=" "): + """Generate Fortran declaration with line continuation if exceeding 120 chars. + Fortran free-format limit is 132 chars; we use 120 for safety margin. + Continuation uses '&' at end of line and '&' at start of continuation. + The comma separator must appear at the end of the line (before &) + so that the continuation line can start cleanly with the next variable. + """ + if not vars_list: + return [] + max_len = 120 + prefix = f"{indent}{type_decl} :: " + # Try single line first + single_line = prefix + ", ".join(vars_list) + if len(single_line) <= max_len: + return [single_line] + # Split across multiple lines with continuation + # Strategy: each line ends with ", &" (comma before ampersand) + # and continuation lines start with "& " then the next variable + result_lines = [] + current = prefix + first = True + for v in vars_list: + # Check if adding this variable (with separator) would exceed limit + if first: + candidate = current + v + else: + candidate = current + ", " + v + if len(candidate) + 2 > max_len and not first: + # End current line with comma + ampersand for continuation + result_lines.append(current + ", &") + current = f"{indent}& {v}" + first = False + else: + current = candidate + first = False + result_lines.append(current) + return result_lines + + lines = [ + "! Generated by sympy_codegen.py. Do not edit.", + "!", + f"! Subroutine: compute_{model.name}", + "!", + "! Input array layout (in_vec):", + ] + + # 添加输入信息 + for i, name in enumerate(model.input_names): + lines.append(f"! in_vec({i + 1}): {name}") + + lines.append("!") + lines.append("! Output array layout (out_vec):") + + # 添加输出信息 + for i, name in enumerate(model.output_names): + lines.append(f"! out_vec({i + 1}): {name}") + + lines.extend([ + "!", + f"subroutine compute_{model.name}(in_vec, out_vec)", + " implicit none", + f" double precision, intent(in) :: in_vec({len(model.inputs)})", + f" double precision, intent(out) :: out_vec({len(model.outputs)})", + " ! --- Unpack inputs ---", + ]) + + # Unpack input array to named variables + input_vars = [] + for i, sym in enumerate(model.inputs): + s = str(sym) + if s.isidentifier(): + input_vars.append(s) + + # First, declare all input variables (with line continuation if needed) + if input_vars: + lines.extend(_fortran_declare("double precision", input_vars, " ")) + + # Then assign values + for i, sym in enumerate(model.inputs): + s = str(sym) + if s.isidentifier(): + lines.append(f" {s} = in_vec({i + 1})") + + # 外部算子调用 + if lowered.external_calls: + lines.append("") + lines.append(" ! --- External Operator Calls ---") + # 先声明所有外部算子相关的变量 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.extend(_fortran_declare("double precision", + [f"{call.prefix}_in({op.n_inputs})", f"{call.prefix}_out({op.n_outputs})"], " ")) + out_var_names = [str(sym) for sym in call.output_symbols] + lines.extend(_fortran_declare("double precision", out_var_names, " ")) + + # 然后赋值和调用 + for call in lowered.external_calls: + op = lowered.external_ops[call.op_name] + lines.append(f" ! External Operator: {call.op_name}") + for i, expr in enumerate(call.input_exprs): + lines.append(f" {call.prefix}_in({i + 1}) = {printer.doprint(expr)}") + lines.append(f" call {op.fortran_func}({call.prefix}_in, {call.prefix}_out)") + for i, sym in enumerate(call.output_symbols): + lines.append(f" {sym} = {call.prefix}_out({i + 1})") + + lines.append("") + + lines.append(" ! --- Local Variables for CSE ---") + + # 使用 lowered 结果 + for chunk in lowered.chunks: + lines.append(f" ! Chunk {chunk.chunk_index}") + if chunk.sub_exprs: + lines.append(" block") + + # Separate variables by type: logical for comparisons, double precision otherwise + dp_vars = [] + log_vars = [] + for var, expr in chunk.sub_exprs: + if isinstance(expr, Relational): + log_vars.append(str(var)) + else: + dp_vars.append(str(var)) + + if dp_vars: + lines.extend(_fortran_declare("double precision", dp_vars, " ")) + if log_vars: + lines.extend(_fortran_declare("logical", log_vars, " ")) + + # Then assign values + for var, expr in chunk.sub_exprs: + lines.append(f" {var} = {printer.doprint(expr)}") + + for j, out_expr in enumerate(chunk.simplified_outputs): + # Fortran arrays are 1-based. + lines.append(f" out_vec({chunk.start_index + j + 1}) = {printer.doprint(out_expr)}") + + if chunk.sub_exprs: + lines.append(" end block") + + lines.append(f"end subroutine compute_{model.name}") + src = "\n".join(lines) + # Replace C-style array access in[i] with Fortran 1-based in_vec(i+1) + # This handles SymPy symbols like in[0], in[1] that are not valid identifiers + def _replace_in_array(m): + idx = int(m.group(1)) + return f"in_vec({idx + 1})" + src = re.sub(r'in\[(\d+)\]', _replace_in_array, src) + return src + + @staticmethod + def compile_element(element: Element, target: str, chunk_size=None, cse_strategy="auto"): + """ + Special compiler for Elements: supports both single-kernel and operator-based generation. + """ + operators = element.get_stiffness_operators() + if operators: + # Generate multiple operator kernels + generated = {} + for op_model in operators: + generated[op_model.name] = MathCompiler.compile(op_model, target, + chunk_size=chunk_size, cse_strategy=cse_strategy) + return generated + else: + # Traditional single kernel + model = element.get_stiffness_model() + return {model.name: MathCompiler.compile(model, target, + chunk_size=chunk_size, cse_strategy=cse_strategy)} + + # ========================================================================= + # 共享工具方法:表达式打印 & 索引替换 + # ========================================================================= + @staticmethod + def _print_expr(expr, printer): + """将 SymPy 表达式或原始值打印为字符串""" + if isinstance(expr, (int, float)): + return str(expr) + if isinstance(expr, str): + return expr + if isinstance(expr, sp.Basic): + return printer.doprint(expr) + return str(expr) + + @staticmethod + def _substitute_index(body, index_sym, value): + """将 body 中所有引用 index_sym 的表达式替换为具体值,返回新的 body 列表""" + import copy + new_body = [] + for stmt in body: + if isinstance(stmt, Assign): + new_expr = stmt.expr + new_target = stmt.target + if isinstance(new_expr, sp.Basic): + new_expr = new_expr.subs(index_sym, value) + if isinstance(new_target, sp.Basic): + new_target = new_target.subs(index_sym, value) + new_body.append(Assign(new_target, new_expr)) + + elif isinstance(stmt, Call): + new_input_exprs = [] + for e in stmt.input_exprs: + if isinstance(e, sp.Basic): + new_input_exprs.append(e.subs(index_sym, value)) + else: + new_input_exprs.append(e) + new_body.append(Call(stmt.model_name, new_input_exprs, stmt.output_vars)) + + elif isinstance(stmt, If): + new_cond = stmt.cond + if isinstance(new_cond, sp.Basic): + new_cond = new_cond.subs(index_sym, value) + new_then = MathCompiler._substitute_index(stmt.then_body, index_sym, value) + new_else = MathCompiler._substitute_index(stmt.else_body, index_sym, value) + new_body.append(If(new_cond, new_then, new_else)) + + elif isinstance(stmt, For): + # 不替换嵌套 For 的 index(不同循环变量),但替换 body 内引用外层 index 的部分 + new_start = stmt.start.subs(index_sym, value) if isinstance(stmt.start, sp.Basic) else stmt.start + new_end = stmt.end.subs(index_sym, value) if isinstance(stmt.end, sp.Basic) else stmt.end + new_body_inner = MathCompiler._substitute_index(stmt.body, index_sym, value) + new_body.append(For(stmt.index, new_start, new_end, new_body_inner, stmt.unroll)) + + else: + # BufferFill, BufferCopy, BufferAccum — 不含表达式,直接拷贝 + new_body.append(stmt) + + return new_body From 1aa473eb2305ee829741060555bde4ca8ed3feb1 Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Sat, 25 Apr 2026 00:01:52 +0800 Subject: [PATCH 13/14] =?UTF-8?q?refactor(cli):=20=E9=87=8D=E6=9E=84?= =?UTF-8?q?=E5=91=BD=E4=BB=A4=E8=A1=8C=E7=BB=93=E6=9E=84=E4=B8=BA=E5=AD=90?= =?UTF-8?q?=E5=91=BD=E4=BB=A4=E6=A8=A1=E5=BC=8F?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- codegen/cli.py | 331 +++++++++++++++++++++++++++---------------------- manual.md | 133 +++++++++++++++----- pyproject.toml | 2 +- 3 files changed, 281 insertions(+), 185 deletions(-) diff --git a/codegen/cli.py b/codegen/cli.py index e256c39..b877937 100644 --- a/codegen/cli.py +++ b/codegen/cli.py @@ -22,170 +22,78 @@ def _default_output(model_name: str, target: str) -> str: return f"{model_name}_gen.f90" return f"{model_name}_{t}.txt" -def main(): - parser = argparse.ArgumentParser( - description="SymPy FEA 代码生成器 (混合解耦架构)" - ) - parser.add_argument( - "--task", - required=True, - choices=["constitutive", "stiffness", "mass", "custom", "flow"], - help="生成任务: 'constitutive' (材料D矩阵), 'stiffness' (单元Ke矩阵), 'mass' (质量矩阵), 'custom' (自定义数学模型), 或 'flow' (流程模型)", - ) - parser.add_argument( - "--element", "-e", - help="单元名称 (e.g., 'tet4'), required for --task=stiffness", - ) - parser.add_argument( - "--material", "-m", - help="材料名称 (e.g., 'isotropic'), required for --task=constitutive", - ) - parser.add_argument( - "--script", "-s", - help="Python 脚本路径 (用于 --task=custom). 脚本中需要提供 get_model() 函数返回 MathModel.", - ) - parser.add_argument( - "--target", "-t", - required=True, - choices=["jax", "cpp", "cuda", "fortran", "all"], - help="目标语言:jax / cpp / cuda / fortran / all", - ) - parser.add_argument( - "--output", "-o", - default=None, - help="输出文件路径(默认根据任务和名称生成)", - ) - parser.add_argument( - "--chunk-size", - type=int, - default=None, - help="CSE chunk size. 如果省略,则使用 cse-strategy 决定。", - ) - parser.add_argument( - "--cse-strategy", - choices=["auto", "fixed"], - default="auto", - help="CSE chunk sizing 策略。'auto' 根据输出规模自动调整,'fixed' 使用固定默认值。", - ) - parser.add_argument( - "--test", - action="store_true", - default=False, - help="同时生成 CI 测试资产(C++/Fortran wrapper、test_driver.py、build 脚本)", - ) - parser.add_argument( - "--test-output-dir", - default=None, - help="测试资产输出目录(仅在 --test 启用时有效,默认与 --output 相同)", - ) - args = parser.parse_args() - if args.task == "constitutive": - if not args.material: - parser.error("--material is required for --task=constitutive") - material = load_material(args.material) - model = material.get_constitutive_model() - models_to_compile = {model.name: model} - - elif args.task == "stiffness": - if not args.element: - parser.error("--element is required for --task=stiffness") - element = load_element(args.element) - operators = element.get_stiffness_operators() - if operators: - models_to_compile = {op.name: op for op in operators} - else: - m = element.get_stiffness_model() - models_to_compile = {m.name: m} +def _load_script(script_path: Path, required_func: str, module_name: str): + """动态加载 Python 脚本并验证其包含指定函数。""" + if not script_path.exists(): + raise SystemExit(f"Error: Script file not found: {script_path}") + spec = importlib.util.spec_from_file_location(module_name, str(script_path)) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + if not hasattr(mod, required_func): + raise SystemExit(f"Error: Script {script_path} must define a '{required_func}()' function.") + return mod - elif args.task == "mass": - if not args.element: - parser.error("--element is required for --task=mass") - element = load_element(args.element) - operators = element.get_mass_operators() - if operators: - models_to_compile = {op.name: op for op in operators} - else: - # Fallback if no specific mass model is defined, though mass usually has operators - parser.error(f"No mass operators defined for element: {args.element}") - - elif args.task == "custom": - if not args.script: - parser.error("--script is required for --task=custom") - script_path = Path(args.script) - if not script_path.exists(): - parser.error(f"Script file not found: {script_path}") - - # Dynamically load the script - spec = importlib.util.spec_from_file_location("custom_script", str(script_path)) - custom_mod = importlib.util.module_from_spec(spec) - sys.modules["custom_script"] = custom_mod - spec.loader.exec_module(custom_mod) - - if not hasattr(custom_mod, "get_model"): - parser.error(f"Script {script_path} must define a 'get_model()' function.") - - models = custom_mod.get_model() + +def _cmd_compile(args): + """feacodegen compile — 编译 MathModel(类比 gcc -c)""" + models_to_compile = {} + task_name = None # 用于 test 资产命名 + + if args.source: + # 从脚本加载 MathModel + script_path = Path(args.source) + mod = _load_script(script_path, "get_model", "compile_script") + models = mod.get_model() if type(models).__name__ == "MathModel": models_to_compile = {models.name: models} elif isinstance(models, list) and all(type(m).__name__ == "MathModel" for m in models): models_to_compile = {m.name: m for m in models} else: - parser.error(f"get_model() must return a MathModel or a list of MathModels. Got: {type(models)}") - - elif args.task == "flow": - if not args.script: - parser.error("--script is required for --task=flow") - script_path = Path(args.script) - if not script_path.exists(): - parser.error(f"Script file not found: {script_path}") - - # Dynamically load the script - spec = importlib.util.spec_from_file_location("flow_script", str(script_path)) - flow_mod = importlib.util.module_from_spec(spec) - sys.modules["flow_script"] = flow_mod - spec.loader.exec_module(flow_mod) - - if not hasattr(flow_mod, "get_flow_model"): - parser.error(f"Script {script_path} must define a 'get_flow_model()' function.") - - flow = flow_mod.get_flow_model() - if not isinstance(flow, FlowModel): - parser.error(f"get_flow_model() must return a FlowModel. Got: {type(flow)}") - - # FlowModel 使用单独的编译路径 - target = args.target - if target == "all": - targets = ["cpp", "cuda", "fortran", "jax"] - else: - targets = [target] + raise SystemExit(f"Error: get_model() must return a MathModel or a list of MathModels. Got: {type(models)}") + task_name = "compile" - for t in targets: - code = FEACompiler.compile_flow( - flow, t, - chunk_size=args.chunk_size, - cse_strategy=args.cse_strategy, - ) - out_path = Path(args.output or ".") / _default_output(flow.name, t) - out_path.parent.mkdir(parents=True, exist_ok=True) - with open(out_path, "w", encoding="utf-8") as f: - f.write(code) - print(f"Generated: {out_path}") + elif args.material: + # 从内置材料加载 + material = load_material(args.material) + model = material.get_constitutive_model() + models_to_compile = {model.name: model} + task_name = "constitutive" - return + elif args.element: + # 从内置单元加载 + element = load_element(args.element) + if args.task == "mass": + operators = element.get_mass_operators() + if operators: + models_to_compile = {op.name: op for op in operators} + else: + raise SystemExit(f"Error: No mass operators defined for element: {args.element}") + else: + # 默认 stiffness + operators = element.get_stiffness_operators() + if operators: + models_to_compile = {op.name: op for op in operators} + else: + m = element.get_stiffness_model() + models_to_compile = {m.name: m} + task_name = args.task or "stiffness" - # ---------------- Compile Models ---------------- + else: + raise SystemExit("Error: Must specify one of: , --material, or --element") + + # 编译 & 输出 base_test_dir = Path(args.test_output_dir or args.output or ".") if args.test else None for name, model in models_to_compile.items(): if args.target == "all": - # --target all: 使用 compile_all 实现真正的共享 CSE generated = FEACompiler.compile_all( model, chunk_size=args.chunk_size, cse_strategy=args.cse_strategy, test=args.test, - task=args.task, + task=task_name, model_name=args.material or args.element or name, ) for t, code in generated.items(): @@ -195,7 +103,6 @@ def main(): with open(out_path, "w", encoding="utf-8") as f: f.write(code) print(f"Generated: {out_path}") - # Also copy kernel source to test directory if --test is enabled if args.test and base_test_dir is not None: kernel_dir = base_test_dir / name kernel_dir.mkdir(parents=True, exist_ok=True) @@ -205,7 +112,6 @@ def main(): with open(kernel_path, "w", encoding="utf-8") as f: f.write(code) elif args.test and t in ("cpp_wrapper", "f90_wrapper", "test_driver", "build_sh", "build_bat"): - # Each model gets its own subdirectory to avoid overwriting test_dir = base_test_dir / name test_dir.mkdir(parents=True, exist_ok=True) fname_map = { @@ -219,11 +125,10 @@ def main(): with open(out_path, "w", encoding="utf-8") as f: f.write(code) print(f"Generated: {out_path}") - # Generate codegen_meta.json for test_driver.py to locate sympy_codegen if args.test and base_test_dir is not None: test_dir = base_test_dir / name test_dir.mkdir(parents=True, exist_ok=True) - code_gen_dir = Path(__file__).resolve().parent.parent # codegen/ -> fea_codegen/ + code_gen_dir = Path(__file__).resolve().parent.parent rel_path = os.path.relpath(code_gen_dir, test_dir.resolve()) meta = {"code_gen_rel_path": rel_path} meta_path = test_dir / "codegen_meta.json" @@ -231,7 +136,6 @@ def main(): _json.dump(meta, f, indent=2) print(f"Generated: {meta_path}") else: - # 单一目标编译 code = FEACompiler.compile( model, args.target, @@ -239,6 +143,133 @@ def main(): cse_strategy=args.cse_strategy, ) out_path = Path(args.output or ".") / _default_output(name, args.target) + out_path.parent.mkdir(parents=True, exist_ok=True) with open(out_path, "w", encoding="utf-8") as f: f.write(code) print(f"Generated: {out_path}") + + +def _cmd_link(args): + """feacodegen link — 编译 + 链接 FlowModel(类比 gcc 编译项目 + 链接)""" + if not args.source: + raise SystemExit("Error: script path is required for link") + + script_path = Path(args.source) + mod = _load_script(script_path, "get_flow_model", "link_script") + flow = mod.get_flow_model() + if not isinstance(flow, FlowModel): + raise SystemExit(f"Error: get_flow_model() must return a FlowModel. Got: {type(flow)}") + + target = args.target + if target == "all": + targets = ["cpp", "cuda", "fortran", "jax"] + else: + targets = [target] + + for t in targets: + code = FEACompiler.compile_flow( + flow, t, + chunk_size=args.chunk_size, + cse_strategy=args.cse_strategy, + ) + out_path = Path(args.output or ".") / _default_output(flow.name, t) + out_path.parent.mkdir(parents=True, exist_ok=True) + with open(out_path, "w", encoding="utf-8") as f: + f.write(code) + print(f"Generated: {out_path}") + + +def _add_compile_options(parser): + """为 compile 子命令添加通用编译选项。""" + parser.add_argument( + "--target", "-t", + required=True, + choices=["jax", "cpp", "cuda", "fortran", "all"], + help="目标语言:jax / cpp / cuda / fortran / all", + ) + parser.add_argument( + "--output", "-o", + default=None, + help="输出文件路径(默认根据任务和名称生成)", + ) + parser.add_argument( + "--chunk-size", + type=int, + default=None, + help="CSE chunk size. 如果省略,则使用 cse-strategy 决定。", + ) + parser.add_argument( + "--cse-strategy", + choices=["auto", "fixed"], + default="auto", + help="CSE chunk sizing 策略。'auto' 根据输出规模自动调整,'fixed' 使用固定默认值。", + ) + parser.add_argument( + "--test", + action="store_true", + default=False, + help="同时生成 CI 测试资产(C++/Fortran wrapper、test_driver.py、build 脚本)", + ) + parser.add_argument( + "--test-output-dir", + default=None, + help="测试资产输出目录(仅在 --test 启用时有效,默认与 --output 相同)", + ) + + +def main(): + parser = argparse.ArgumentParser( + prog="feacodegen", + description="FEA 代码生成器 — 将高层级数学定义自动转换为优化的 C++/CUDA/Fortran/JAX 计算内核", + ) + subparsers = parser.add_subparsers(dest="command", help="子命令") + + # ─── compile 子命令 ─── + compile_parser = subparsers.add_parser( + "compile", + help="编译 MathModel(类比 gcc -c,编译单个数学模型源文件)", + description="编译 MathModel → 目标语言代码。可从 Python 脚本、内置材料或内置单元加载模型。", + ) + compile_parser.add_argument( + "source", + nargs="?", + default=None, + help="Python 脚本路径(脚本需定义 get_model() 返回 MathModel)", + ) + compile_parser.add_argument( + "--material", "-m", + help="内置材料名称(如 isotropic)", + ) + compile_parser.add_argument( + "--element", "-e", + help="内置单元名称(如 tet4, hex8r)", + ) + compile_parser.add_argument( + "--task", + choices=["stiffness", "mass"], + default="stiffness", + help="对内置单元的操作类型(默认 stiffness,仅 --element 时有效)", + ) + _add_compile_options(compile_parser) + compile_parser.set_defaults(func=_cmd_compile) + + # ─── link 子命令 ─── + link_parser = subparsers.add_parser( + "link", + help="编译 + 链接 FlowModel(类比 gcc 编译项目,先编译子模型再链接主流程)", + description="编译 FlowModel → 目标语言代码。先编译各子模型函数,再生成主流程函数并拼接。", + ) + link_parser.add_argument( + "source", + help="Python 脚本路径(脚本需定义 get_flow_model() 返回 FlowModel)", + ) + _add_compile_options(link_parser) + link_parser.set_defaults(func=_cmd_link) + + args = parser.parse_args() + + if args.command is None: + parser.print_help() + sys.exit(1) + + args.func(args) diff --git a/manual.md b/manual.md index 608c211..c906b42 100644 --- a/manual.md +++ b/manual.md @@ -4,56 +4,121 @@ 本工具将高层级数学定义自动转换为优化的 C++ / CUDA / Fortran / JAX 计算内核。 -支持五类生成任务: +设计理念参考 gcc 编译器模型: -| 任务 | 说明 | 必需参数 | -|------|------|----------| -| `constitutive` | 材料 D 矩阵 | `--material` | -| `stiffness` | 单元刚度矩阵 Ke | `--element` | -| `mass` | 单元质量矩阵 | `--element` | -| `custom` | 自定义数学模型 | `--script` | -| `flow` | 流程模型 | `--script` | +| 类比 | feacodegen | gcc | +|------|-----------|-----| +| 编译单个源文件 | `feacodegen compile` | `gcc -c file.c` | +| 编译项目 + 链接 | `feacodegen link` | `gcc file1.o file2.o -o prog` | --- ## 2. 命令行接口 ```bash -python sympy_codegen.py --task <任务> --target <语言> [选项] +feacodegen [options] ``` -### 2.1 参数一览 +### 2.1 子命令一览 -| 参数 | 缩写 | 必需 | 说明 | -|------|------|------|------| -| `--task` | | 是 | 生成任务:`constitutive` / `stiffness` / `mass` / `custom` / `flow` | -| `--target` | `-t` | 是 | 目标语言:`cpp` / `cuda` / `fortran` / `jax` / `all` | -| `--material` | `-m` | 条件 | 材料名称(`--task constitutive` 时必需) | -| `--element` | `-e` | 条件 | 单元名称(`--task stiffness|mass` 时必需) | -| `--script` | `-s` | 条件 | Python 脚本路径(`--task custom|flow` 时必需) | -| `--output` | `-o` | 否 | 输出路径,默认自动生成 | -| `--chunk-size` | | 否 | CSE 分块大小,省略则由策略自动决定 | -| `--cse-strategy` | | 否 | `auto`(默认)或 `fixed` | -| `--test` | | 否 | 同时生成 CI 测试资产 | -| `--test-output-dir` | | 否 | 测试资产输出目录,默认与 `--output` 相同 | +| 子命令 | 说明 | 类比 | +|--------|------|------| +| `compile` | 编译 MathModel → 目标代码 | `gcc -c`(编译单个源文件) | +| `link` | 编译 + 链接 FlowModel → 目标代码 | `gcc` 编译项目 + 链接 | + +--- + +### 2.2 `feacodegen compile` — 编译 MathModel + +将 MathModel(声明式数学模型)编译为目标语言代码。 + +```bash +feacodegen compile -t [options] +feacodegen compile --material -t [options] +feacodegen compile --element -t [options] +``` + +**输入来源**(三选一,优先级:source > --material > --element): + +| 来源 | 说明 | 必需参数 | +|------|------|----------| +| 脚本路径 `` | Python 脚本,需定义 `get_model()` 返回 `MathModel` | 位置参数 | +| `--material` | 内置材料名称(如 `isotropic`) | `-m` | +| `--element` | 内置单元名称(如 `tet4`, `hex8r`) | `-e`,可选 `--task` | + +#### compile 专用参数 + +| 参数 | 缩写 | 说明 | +|------|------|------| +| `` | | Python 脚本路径(位置参数,可选) | +| `--material` | `-m` | 内置材料名称 | +| `--element` | `-e` | 内置单元名称 | +| `--task` | | 对内置单元的操作类型:`stiffness`(默认)或 `mass`,仅 `--element` 时有效 | + +#### 典型用法 + +```bash +# 从脚本编译 MathModel +feacodegen compile my_model.py -t cpp + +# 从内置材料编译 D 矩阵 +feacodegen compile --material isotropic -t cpp -### 2.2 典型用法 +# 从内置单元编译刚度算子 +feacodegen compile --element tet4 -t all + +# 从内置单元编译质量算子 +feacodegen compile --element hex8r --task mass -t cuda + +# 编译到全平台 +feacodegen compile my_model.py -t all +``` + +--- + +### 2.3 `feacodegen link` — 编译 + 链接 FlowModel + +将 FlowModel(命令式流程模型)编译为目标语言代码。编译器会先编译所有子模型函数,再生成主流程函数并拼接。 ```bash -# 材料 D 矩阵 → C++ -python sympy_codegen.py --task constitutive --material isotropic --target cpp +feacodegen link -t [options] +``` + +#### link 专用参数 + +| 参数 | 缩写 | 说明 | +|------|------|------| +| `` | | Python 脚本路径(必需,脚本需定义 `get_flow_model()` 返回 `FlowModel`) | + +#### 典型用法 -# 单元刚度算子 → 全平台 -python sympy_codegen.py --task stiffness --element tet4 --target all +```bash +# 编译 FlowModel → C++ +feacodegen link my_flow.py -t cpp -# 自定义模型 → JAX -python sympy_codegen.py --task custom --script my_model.py --target jax +# 编译 FlowModel → CUDA +feacodegen link my_flow.py -t cuda -# 流程模型 → CUDA -python sympy_codegen.py --task flow --script my_flow.py --target cuda +# 编译 FlowModel → 全平台 +feacodegen link my_flow.py -t all ``` -### 2.3 输出文件命名 +--- + +### 2.4 通用编译选项 + +以下选项同时适用于 `compile` 和 `link`: + +| 参数 | 缩写 | 必需 | 说明 | +|------|------|------|------| +| `--target` | `-t` | 是 | 目标语言:`cpp` / `cuda` / `fortran` / `jax` / `all` | +| `--output` | `-o` | 否 | 输出路径,默认自动生成 | +| `--chunk-size` | | 否 | CSE 分块大小,省略则由策略自动决定 | +| `--cse-strategy` | | 否 | `auto`(默认)或 `fixed` | +| `--test` | | 否 | 同时生成 CI 测试资产(仅 compile) | +| `--test-output-dir` | | 否 | 测试资产输出目录,默认与 `--output` 相同 | + +### 2.5 输出文件命名 | 目标 | 扩展名 | |------|--------| @@ -336,7 +401,7 @@ def get_flow_model(): 然后生成代码: ```bash -python sympy_codegen.py --task flow --script my_flow.py --target all +feacodegen link my_flow.py -t all ``` ### 4.5 各目标平台的语义映射 @@ -364,7 +429,7 @@ python sympy_codegen.py --task flow --script my_flow.py --target all 添加 `--test` 标志即可在生成内核的同时生成完整的交叉验证测试套件: ```bash -python sympy_codegen.py --task stiffness --element tet4 --target all --test --output generated/tet4 +feacodegen compile --element tet4 -t all --test --output generated/tet4 ``` 生成的测试目录结构: diff --git a/pyproject.toml b/pyproject.toml index 301ae9e..58bb2a3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,4 +25,4 @@ dev = [ ] [project.scripts] -fea-codegen = "sympy_codegen:main" +feacodegen = "sympy_codegen:main" From 1f1ef6146af2faa3640f1f8c0dd67775be91bedc Mon Sep 17 00:00:00 2001 From: "xiaotong.wang" <18648483389@163.com> Date: Sat, 25 Apr 2026 00:14:33 +0800 Subject: [PATCH 14/14] =?UTF-8?q?build:=20=E6=9B=B4=E6=96=B0=20pyproject.t?= =?UTF-8?q?oml=20=E6=9E=84=E5=BB=BA=E9=85=8D=E7=BD=AE?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pyproject.toml | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 58bb2a3..80da217 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,13 +1,13 @@ [build-system] requires = ["setuptools>=68.0"] -build-backend = "setuptools.backends._legacy:_Backend" +build-backend = "setuptools.build_meta" [project] name = "fea-codegen" version = "0.1.0" description = "Finite Element Analysis Code Generator — SymPy-based kernel generation for C++/CUDA/Fortran/JAX" readme = "README.md" -license = {text = "MPL-2.0"} +license = "MPL-2.0" requires-python = ">=3.10" dependencies = [ "sympy>=1.12", @@ -26,3 +26,7 @@ dev = [ [project.scripts] feacodegen = "sympy_codegen:main" + +[tool.setuptools] +packages = ["codegen", "definitions"] +py-modules = ["sympy_codegen"]