commit abe9dca122aaec0f6bb88a80b5733f46c06da088
parent 33306912a82cb2efdc93b2e25d84704c085100e3
Author: m21c <ho*******@gmail.com>
Date: Sat, 3 Apr 2021 21:39:14 +0200
worked on type-checking & folding + corrected coding style
Diffstat:
| M | aria.c | | | 265 | ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------- |
1 file changed, 228 insertions(+), 37 deletions(-)
diff --git a/aria.c b/aria.c
@@ -1078,6 +1078,8 @@ enum DeclKind {
struct Decl {
DeclKind kind;
+ Type *type;
+
Env *env;
Node *declnode;
@@ -1150,6 +1152,7 @@ makedecl(int key, DeclKind kind) {
decl->kind = kind;
decl->key = key;
+ decl->type = prim + TVOID;
assert(currenv);
@@ -1568,6 +1571,7 @@ declaration(Node *typenode) {
if (tok.kind == 'I') {
Decl *decl = makedecl(tok.u.key, DVAR);
+ decl->type = typenode->type;
result = makenode(typenode);
result->kind = ADECL;
result->rhs = makenode(NULL);
@@ -1792,6 +1796,7 @@ atom(int flags) {
if (lhs->u.declref) {
lhs->kind = ADECLREF;
+ lhs->type = lhs->u.declref->type;
} else {
Env *funcenv = getfuncenv();
if (funcenv) {
@@ -2077,31 +2082,118 @@ exprlist(bool isparam, Node *paramtype) {
/* - type-checking & folding - */
+bool
+isinttype(Type *ty) {
+ switch (ty->kind) {
+ case TU8: case TS8:
+ case TU16: case TS16:
+ case TU32: case TS32:
+ case TU64: case TS64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool
+isfloattype(Type *ty) {
+ switch (ty->kind) {
+ case TF32: case TF64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool
+isarithtype(Type *ty) {
+ switch (ty->kind) {
+ case TU8: case TS8:
+ case TU16: case TS16:
+ case TU32: case TS32:
+ case TU64: case TS64:
+ case TF32: case TF64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+bool
+isunsignedtype(Type *ty) {
+ switch (ty->kind) {
+ case TU8: case TU16:
+ case TU32: case TU64:
+ return true;
+ default:
+ return false;
+ }
+}
+
+/* TODO(m21c): also mask int/float values in the tokenizer */
+uint64_t
+maskint(int size, uint64_t value) {
+ if (size == 1) return value & 0xfful;
+ if (size == 2) return value & 0xfffful;
+ if (size == 4) return value & 0xfffffffful;
+ return value;
+}
+
+double
+maskfloat(int size, double value) {
+ if (size == 4) return (double) (float) value;
+ return value;
+}
+
+uint64_t
+convint(int srcsize, bool srcsigned, uint64_t value) {
+ if (!srcsigned) return value;
+ if (srcsize == 1) return (uint64_t) (int8_t ) value;
+ if (srcsize == 2) return (uint64_t) (int16_t) value;
+ if (srcsize == 4) return (uint64_t) (int32_t) value;
+ return value;
+}
+
Node *
-conv(Node *node)
-{
+conv(Node *node) {
return node;
}
Node *
-wrap(Type *ty, Node *node)
-{
+wrap(Type *ty, Node *node) {
if (node->type && ty->kind == node->type->kind)
return node;
- if (node->kind == 'N')
- return node->type = ty, node;
+
+ if (node->kind == 'N') {
+ /* TODO(m21c): layout correct type-conversions ? */
+ if (isfloattype(node->type)) {
+ if (isfloattype(ty))
+ node->u.d = maskfloat(ty->size, node->u.d);
+ else if (isinttype(ty))
+ node->u.u = maskint(ty->size, (int64_t) node->u.d);
+ } else if (isinttype(node->type)) {
+ if (isfloattype(ty)) {
+ node->u.d = maskfloat(ty->size, (double)
+ (int64_t) convint(node->type->size,
+ !isunsignedtype(node->type),
+ node->u.u));
+ } else if (isinttype(ty)) {
+ node->u.u = maskint(ty->size,
+ convint(node->type->size,
+ !isunsignedtype(node->type),
+ node->u.u));
+ }
+ }
+ node->type = ty;
+ return node;
+ }
+
node = makenode(node);
node->kind = ACONV;
node->type = ty;
return node;
}
-Type *
-usualarithconv(Type *lt, Type *rt)
-{
- return lt;
-}
-
typedef
Node *(*RuleFunc)(Node *expr);
@@ -2109,13 +2201,13 @@ Node *
foldexpr(Node *expr);
Node *
-identrule(Node *ident)
-{
+identrule(Node *ident) {
Decl *declref = finddeclaration(ident->u.key);
if (declref) {
ident->kind = ADECLREF;
ident->u.declref = declref;
+ ident->type = declref->type;
} else {
error("'%s' undeclared", getstring(idents, ident->u.key));
}
@@ -2124,8 +2216,7 @@ identrule(Node *ident)
}
Node *
-binaryarithrule(Node *binary)
-{
+binaryarithrule(Node *binary) {
Node *lhs = binary->lhs;
Node *rhs = binary->rhs;
Type *tt;
@@ -2135,16 +2226,101 @@ binaryarithrule(Node *binary)
lhs = foldexpr(lhs);
rhs = foldexpr(rhs);
- tt = usualarithconv(lhs->type, rhs->type);
+
+ /* usual arithmetic conversion */
+ if (isarithtype(lhs->type) && isarithtype(rhs->type)) {
+ if (lhs->type->kind < rhs->type->kind)
+ tt = rhs->type;
+ else
+ tt = lhs->type;
+ } else {
+ tt = prim + TVOID;
+ }
+
lhs = wrap(tt, lhs);
rhs = wrap(tt, rhs);
binary->type = tt;
+ #define evalbinary(op) do { \
+ binary->kind = 'N'; \
+ if (isfloattype(tt)) \
+ binary->u.d = maskfloat(tt->size, \
+ maskfloat(tt->size, lhs->u.d) op \
+ maskfloat(tt->size, rhs->u.d) \
+ ); \
+ else if (isinttype(tt)) \
+ binary->u.u = maskint(tt->size, \
+ maskint(tt->size, lhs->u.u) op \
+ maskint(tt->size, rhs->u.u) \
+ ); \
+ /* delete(lhs); delete(rhs) */ \
+ } while (0)
+
+ #define isvalue(expr, value) (expr->kind == 'N' && \
+ ((expr->u.u == value && isinttype(tt)) || \
+ (expr->u.d == value && isarithtype(tt))))
+
+ switch (binary->kind) {
+ case OADD: case OSUB:
+ if (lhs->kind == 'N' && rhs->kind == 'N') {
+ if (binary->kind == OADD) evalbinary(+);
+ else evalbinary(-);
+ } else if (isvalue(lhs, 0)) {
+ if (binary->kind == OADD) {
+ *binary = *rhs;
+ /* delete(lhs); delete(rhs) */
+ } else {
+ binary->kind = OMINUS;
+ binary->lhs = rhs;
+ /* delete(lhs) */
+ }
+ } else if (isvalue(rhs, 0)) {
+ *binary = *lhs;
+ /* delete(lhs); delete(rhs) */
+ }
+ break;
+ case OMUL: case ODIV: case OMOD:
+ if (lhs->kind == 'N' && rhs->kind == 'N') {
+ if (binary->kind == OMUL) {
+ evalbinary(*);
+ } else {
+ if (rhs->u.u == 0 && isinttype(tt))
+ error("division by zero");
+ else if (binary->kind == ODIV)
+ evalbinary(/);
+ else
+ evalbinary(-);
+ }
+ } else if (isvalue(lhs, 0)) {
+ *binary = *lhs;
+ /* delete(lhs); delete(rhs) */
+ } else if (binary->kind == OMUL && isvalue(rhs, 0)) {
+ *binary = *rhs;
+ /* delete(lhs); delete(rhs) */
+ } else if (isvalue(rhs, 0)) {
+ if (rhs->u.u == 0 && isinttype(tt))
+ error("division by zero");
+ *binary = *rhs;
+ /* delete(lhs); delete(rhs) */
+ } else if (isvalue(lhs, 1)) {
+ *binary = *rhs;
+ /* delete(lhs); delete(rhs) */
+ } else if (binary->kind == OMUL && isvalue(rhs, 1)) {
+ *binary = *lhs;
+ /* delete(lhs); delete(rhs) */
+ }
+ }
+
binary->lhs = lhs;
binary->rhs = rhs;
return binary;
}
+Node *
+unaryarithrule(Node *unary) {
+ return unary;
+}
+
RuleFunc ruletable[] = {
['I'] = &identrule,
@@ -2200,8 +2376,7 @@ RuleFunc opfunctable[] = {
};
Node *
-oprules(Node *expr)
-{
+oprules(Node *expr) {
return opfunctable[expr->u.id](expr);
}
@@ -2226,16 +2401,14 @@ RuleFunc astfunctable[] = {
};
Node *
-astrules(Node *expr)
-{
+astrules(Node *expr) {
return astfunctable[expr->u.id](expr);
}
#endif
Node *
-foldexpr(Node *expr)
-{
+foldexpr(Node *expr) {
Node *c, *n;
for (c = expr; c; c = c->next) {
@@ -2262,7 +2435,7 @@ foldexpr(Node *expr)
c->lhs = foldexpr(c->lhs);
if (c->rhs)
c->rhs = foldexpr(c->rhs);
-#else
+#elif 1
if (ruletable[c->kind]) {
n = ruletable[c->kind](c);
@@ -2273,6 +2446,27 @@ foldexpr(Node *expr)
c = n;
}
}
+#else
+ Node *lhs, *rhs;
+ Type *lt, *lt, *tt;
+
+ #define convertbinaryarith() \
+ lhs = foldexpr(c->lhs); \
+ rhs = foldexpr(c->rhs); \
+ tt = usualarithconv(lhs->type, rhs->type); \
+ lhs = wrap(tt, lhs); \
+ rhs = wrap(tt, rhs); \
+ c->type = tt
+
+
+ switch (c->kind) {
+ case OMUL: case ODIV: case OMOD:
+ case OADD: case OSUB:
+ convertbinaryarith();
+ break;
+ default:
+ break;
+ }
#endif
}
@@ -2284,8 +2478,7 @@ foldexpr(Node *expr)
int
printexpr(FILE *out, Node *expr, int indent);
-int printtype(FILE *out, Type *type, int indent)
-{
+int printtype(FILE *out, Type *type, int indent) {
int n = 0;
if (!type)
@@ -2311,11 +2504,12 @@ int printtype(FILE *out, Type *type, int indent)
#undef typecase
default:;
}
+
+ return n;
}
bool
-isclauseorempty(Node *expr)
-{
+isclauseorempty(Node *expr) {
Kind kind;
while (expr && (expr->kind == ASCOPE || expr->kind == ASTMT))
@@ -2330,8 +2524,7 @@ isclauseorempty(Node *expr)
}
int
-printclause(FILE *out, Node *expr, int indent)
-{
+printclause(FILE *out, Node *expr, int indent) {
int n = 0;
if (isclauseorempty(expr)) {
@@ -2346,11 +2539,11 @@ printclause(FILE *out, Node *expr, int indent)
}
int
-printstring(FILE *out, Node *string)
-{
+printstring(FILE *out, Node *string) {
const char *str = getstring(strings, string->u.key);
int len = getlength(strings, string->u.key);
int i, n = fprintf(out, "\"");
+
for (i = 0; i < len; ++i) {
switch (str[i]) {
case '\\':
@@ -2379,13 +2572,12 @@ printstring(FILE *out, Node *string)
++n;
}
}
- n += fprintf(out, "\"");
+ return n + fprintf(out, "\"");
}
int
-printoperant(FILE *out, Node *expr, int opprec, bool braceequalprec, int indent)
-{
+printoperant(FILE *out, Node *expr, int opprec, bool braceequalprec, int indent) {
int prec, n = 0;
if (!expr)
@@ -2410,8 +2602,7 @@ printoperant(FILE *out, Node *expr, int opprec, bool braceequalprec, int indent)
}
int
-printexpr(FILE *out, Node *expr, int indent)
-{
+printexpr(FILE *out, Node *expr, int indent) {
Node *c;
int n = 0;