Skip to content

Commit a01af9f

Browse files
authored
Merge pull request #672 from stacks-network/feat/use-duck-typing-for-fold
Duck-typing Part III: Fold
2 parents 9a0c9ad + 9af2066 commit a01af9f

File tree

2 files changed

+91
-78
lines changed

2 files changed

+91
-78
lines changed

clar2wasm/src/lib.rs

Lines changed: 1 addition & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ use clarity::vm::analysis::{run_analysis, AnalysisDatabase, ContractAnalysis};
55
use clarity::vm::ast::{build_ast_with_diagnostics, ContractAST};
66
use clarity::vm::costs::{ExecutionCost, LimitedCostTracker};
77
use clarity::vm::diagnostic::Diagnostic;
8-
use clarity::vm::types::{
9-
FixedFunction, ListTypeData, QualifiedContractIdentifier, SequenceSubtype, TypeSignature,
10-
};
8+
use clarity::vm::types::QualifiedContractIdentifier;
119
use clarity::vm::ClarityVersion;
1210
pub use walrus::Module;
1311
use wasm_generator::{GeneratorError, WasmGenerator};
@@ -105,8 +103,6 @@ pub fn compile(
105103
}
106104
};
107105

108-
typechecker_workaround(&ast, &mut contract_analysis);
109-
110106
// Now that the typechecker pass is done, we can concretize the expressions types which
111107
// might contain `ListUnionType` or `CallableType`
112108
#[allow(clippy::expect_used)]
@@ -154,65 +150,6 @@ pub fn compile(
154150
}
155151
}
156152

157-
// Workarounds to make filter/fold work in cases where it would not otherwise. see issue #488
158-
fn typechecker_workaround(ast: &ContractAST, contract_analysis: &mut ContractAnalysis) {
159-
for expr in ast.expressions.iter() {
160-
match expr
161-
.match_list()
162-
.and_then(|l| l.first())
163-
.and_then(|first| first.match_atom())
164-
.map(|atom| atom.as_str())
165-
{
166-
Some("fold") => {
167-
// in the case of fold we need to override the type of the argument list
168-
169-
let Some(func_expr) = expr.match_list().map(|l| &l[1]) else {
170-
continue;
171-
};
172-
173-
let Some(func_name) = func_expr.match_atom() else {
174-
continue;
175-
};
176-
177-
let return_type = match contract_analysis
178-
.get_private_function(func_name.as_str())
179-
.or(contract_analysis.get_read_only_function_type(func_name.as_str()))
180-
{
181-
Some(clarity::vm::types::FunctionType::Fixed(FixedFunction {
182-
args, ..
183-
})) => args[0].signature.clone(),
184-
_ => continue,
185-
};
186-
187-
let Some(sequence_expr) = expr.match_list().map(|l| &l[2]) else {
188-
continue;
189-
};
190-
191-
if let Some(tmap) = contract_analysis.type_map.as_mut() {
192-
let Some(seq_type) = tmap.get_type(sequence_expr) else {
193-
continue;
194-
};
195-
let TypeSignature::SequenceType(SequenceSubtype::ListType(data)) = seq_type
196-
else {
197-
continue;
198-
};
199-
200-
let Ok(list_data) = ListTypeData::new_list(return_type, data.get_max_len())
201-
else {
202-
continue;
203-
};
204-
205-
tmap.overwrite_type(
206-
sequence_expr,
207-
TypeSignature::SequenceType(SequenceSubtype::ListType(list_data)),
208-
);
209-
}
210-
}
211-
_ => continue,
212-
}
213-
}
214-
}
215-
216153
pub fn compile_contract(contract_analysis: ContractAnalysis) -> Result<Module, GeneratorError> {
217154
let generator = WasmGenerator::new(contract_analysis)?;
218155
generator.generate()

clar2wasm/src/words/sequences.rs

Lines changed: 90 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use clarity::vm::clarity_wasm::get_type_size;
22
use clarity::vm::types::{
3-
FunctionType, ListTypeData, SequenceSubtype, StringSubtype, TypeSignature,
3+
FixedFunction, FunctionType, ListTypeData, SequenceSubtype, StringSubtype, TypeSignature,
44
};
55
use clarity::vm::{ClarityName, SymbolicExpression};
66
use walrus::ir::{self, BinaryOp, IfElse, InstrSeqType, Loop, UnaryOp};
@@ -90,7 +90,7 @@ impl ComplexWord for Fold {
9090
&self,
9191
generator: &mut WasmGenerator,
9292
builder: &mut walrus::InstrSeqBuilder,
93-
_expr: &SymbolicExpression,
93+
expr: &SymbolicExpression,
9494
args: &[SymbolicExpression],
9595
) -> Result<(), GeneratorError> {
9696
check_args!(generator, builder, 3, args.len(), ArgumentCountCheck::Exact);
@@ -112,21 +112,60 @@ impl ComplexWord for Fold {
112112
// (- 6 (- 4 (- 2 0)))
113113
// ```
114114

115-
// WORKAROUND: Get the type of the function being called, and set the
116-
// type of the initial value to match the functions parameter type.
117-
// This is a workaround for the typechecker not being able to infer
118-
// the complete type of initial value.
119-
if let Some(FunctionType::Fixed(fixed)) = generator.get_function_type(func) {
120-
let initial_ty = fixed
121-
.args
122-
.get(1)
115+
// To make sure that the initial value will reserve enough space in memory, we reassign its type to the type of the expression.
116+
generator.set_expr_type(
117+
initial,
118+
generator
119+
.get_expr_type(expr)
123120
.ok_or_else(|| {
124-
GeneratorError::TypeError("expected function with 2 arguments".into())
121+
GeneratorError::TypeError("fold expression should be typed".to_owned())
125122
})?
126-
.signature
127-
.clone();
128-
generator.set_expr_type(initial, initial_ty)?;
123+
.clone(),
124+
)?;
125+
126+
// We need to find the correct types expected by the function `func` and the result type of the fold expression
127+
// to make sure everything will be coherent in the end.
128+
// This is only needed if we are folding a list and the function is user-defined.
129+
struct FoldFuncTy {
130+
elem_ty: TypeSignature,
131+
acc_ty: TypeSignature,
132+
return_ty: TypeSignature,
129133
}
134+
let fold_func_ty = {
135+
match generator.get_expr_type(sequence).ok_or_else(|| {
136+
GeneratorError::TypeError("Folded sequence should be typed".to_owned())
137+
})? {
138+
TypeSignature::SequenceType(SequenceSubtype::ListType(ltd)) => {
139+
match generator.get_function_type(func) {
140+
Some(FunctionType::Fixed(FixedFunction { args, returns }))
141+
if args.len() == 2 =>
142+
{
143+
let fold_func_ty = FoldFuncTy {
144+
elem_ty: args[0].signature.clone(),
145+
acc_ty: args[1].signature.clone(),
146+
return_ty: returns.clone(),
147+
};
148+
// Set the type of the list elements
149+
generator.set_expr_type(
150+
sequence,
151+
TypeSignature::SequenceType(SequenceSubtype::ListType(
152+
ListTypeData::new_list(
153+
fold_func_ty.elem_ty.clone(),
154+
ltd.get_max_len(),
155+
)
156+
.map_err(|e| GeneratorError::TypeError(e.to_string()))?,
157+
)),
158+
)?;
159+
// set the accumulator type
160+
generator.set_expr_type(initial, fold_func_ty.acc_ty.clone())?;
161+
Some(fold_func_ty)
162+
}
163+
_ => None,
164+
}
165+
}
166+
_ => None,
167+
}
168+
};
130169

131170
// The result type must match the type of the initial value
132171
let result_clar_ty = generator
@@ -227,6 +266,10 @@ impl ComplexWord for Fold {
227266
} else {
228267
// Call user defined function
229268
generator.visit_call_user_defined(&mut loop_, &result_clar_ty, func)?;
269+
// since the accumulator and the return type of the function could have different types, we need to duck-type.
270+
if let Some(tys) = &fold_func_ty {
271+
generator.duck_type(&mut loop_, &tys.return_ty, &tys.acc_ty)?;
272+
}
230273
}
231274
// Save the result into the locals (in reverse order as we pop)
232275
for result_local in result_locals.iter().rev() {
@@ -262,6 +305,11 @@ impl ComplexWord for Fold {
262305
alternative: else_id,
263306
});
264307

308+
// since the return type of the function and the accumulator could have different types, we need to duck-type.
309+
if let Some(tys) = &fold_func_ty {
310+
generator.duck_type(builder, &tys.acc_ty, &tys.return_ty)?;
311+
}
312+
265313
Ok(())
266314
}
267315
}
@@ -2354,6 +2402,34 @@ mod tests {
23542402
crosscheck(snippet, Ok(Some(expected)))
23552403
}
23562404

2405+
#[test]
2406+
fn fold_with_response_partial_acc() {
2407+
let snippet = "
2408+
(define-private (foo (a (response int bool)) (b (response int uint)))
2409+
(match b
2410+
bok (ok (+ (unwrap-panic a) bok))
2411+
berr (ok (+ (unwrap-panic a) (to-int berr)))
2412+
)
2413+
)
2414+
(fold foo (list (ok 1)) (ok 2))
2415+
";
2416+
crosscheck(snippet, Ok(Some(Value::okay(Value::Int(3)).unwrap())));
2417+
}
2418+
2419+
#[test]
2420+
fn fold_with_response_full_acc() {
2421+
let snippet = "
2422+
(define-private (foo (a (response int bool)) (b (response int uint)))
2423+
(match b
2424+
bok (ok (+ (unwrap-panic a) bok))
2425+
berr (err (+ (to-uint (unwrap-panic a)) berr))
2426+
)
2427+
)
2428+
(fold foo (list (ok 1)) (ok 2))
2429+
";
2430+
crosscheck(snippet, Ok(Some(Value::okay(Value::Int(3)).unwrap())));
2431+
}
2432+
23572433
#[test]
23582434
fn unit_fold_repsonses_full_type() {
23592435
let snippet = "

0 commit comments

Comments
 (0)